# Source code for hwoutils.radial

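"""Radial utilities for image analysis.

Functions for computing radial distance grids and extracting azimuthally
averaged radial profiles from 2D images. All functions are JIT-compilable
and differentiable with respect to their floating-point array inputs. Radial
bin membership is piecewise constant, so ``radial_profile`` carries no
gradient with respect to ``center``.
"""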

import functools

import jax
import jax.numpy as jnp


[docs] def radial_distance( shape: tuple[int, int], center: tuple[float, float] | None = None, ) -> jax.Array: """Calculate radial distance from center for each pixel. Args: shape: Image shape (ny, nx). center: Center coordinates (cy, cx). If None, uses geometric center ``((ny-1)/2, (nx-1)/2)``. Returns: 2D array of radial distances in pixels. """ ny, nx = shape if center is None: center = ((ny - 1) / 2.0, (nx - 1) / 2.0) cy, cx = center y, x = jnp.ogrid[:ny, :nx] return jnp.sqrt((y - cy) ** 2 + (x - cx) ** 2)
@functools.partial(jax.jit, static_argnames=("nbins",))
def radial_profile(
    image: jax.Array,
    pixel_scale_arcsec: float = 1.0,
    center: tuple[float, float] | None = None,
    nbins: int | None = None,
) -> tuple[jax.Array, jax.Array]:
    """Compute the azimuthally averaged radial profile of a 2D image.

    Pixels are binned by their distance from ``center`` into ``nbins``
    equal-width annuli spanning ``[0, r_max]``, where ``r_max`` is the
    distance to the farthest pixel from the center. The mean pixel value of
    each annulus is returned.

    Args:
        image: 2D input image. Real images (including integer/bool) are
            averaged in the default JAX float dtype. Complex images stay
            complex.
        pixel_scale_arcsec: Conversion factor from pixels to physical units
            (e.g. λ/D per pixel). Default 1.0 gives separations in pixels.
        center: Center coordinates (cy, cx). If None, uses the geometric
            center ``((ny-1)/2, (nx-1)/2)``.
        nbins: Number of radial bins (static under ``jit``). If None, uses
            ``max(1, floor(max_dim / 2))``.

    Returns:
        ``(separations, profile)`` where separations are the bin centers in
        physical units and profile is the mean value in each bin. Bins that
        contain no pixels report 0.

    Raises:
        ValueError: If ``nbins`` is given and is less than 1, or if ``image``
            is not 2D (its shape cannot be unpacked into ``(ny, nx)``).
    """
    ny, nx = image.shape
    if nbins is None:
        # Floor at 1 so 1-pixel-wide images (e.g. shape (1, 1)) do not get
        # zero bins and silently return empty arrays.
        nbins = max(1, max(ny, nx) // 2)
    elif nbins < 1:
        # nbins is a static argument, so this compares a concrete int.
        raise ValueError(f"nbins must be a positive integer, got {nbins}")

    r = radial_distance((ny, nx), center)
    # Equal-width edges from 0 out to the farthest pixel from the center.
    max_radius = jnp.max(r)
    bin_edges = jnp.linspace(0, max_radius, nbins + 1)
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2.0

    # digitize is 1-indexed. A pixel sitting exactly at max_radius gets
    # index nbins + 1, so clamp into the last bin, then shift to 0-based.
    inds = jnp.clip(jnp.digitize(r.ravel(), bin_edges), 1, nbins) - 1

    bin_counts = jnp.zeros(nbins).at[inds].add(1.0)
    # Accumulate in the default float dtype (unchanged for real images) but
    # promote to complex for complex images, so the scatter does not drop
    # their imaginary part.
    sum_dtype = jnp.promote_types(image.dtype, bin_counts.dtype)
    bin_sums = (
        jnp.zeros(nbins, dtype=sum_dtype)
        .at[inds]
        .add(image.ravel().astype(sum_dtype))
    )

    # Divide by a safe denominator so empty bins never evaluate 0/0. That
    # NaN would trip jax_debug_nans even though `where` masks it. Empty bins
    # then report 0.
    filled = bin_counts > 0
    safe_counts = jnp.where(filled, bin_counts, 1.0)
    profile = jnp.where(filled, bin_sums / safe_counts, 0.0)

    separations = bin_centers * pixel_scale_arcsec
    return separations, profile