"""Image transformation utilities.
Flux-conserving resampling and sub-pixel image operations. All functions
are JIT-compilable and differentiable.
"""
import functools
import jax
import jax.numpy as jnp
from hwoutils.map_coordinates import map_coordinates
[docs]
def ccw_rotation_matrix(rotation_deg: float) -> jax.Array:
    """Build the 2x2 counter-clockwise rotation matrix for an angle.

    Args:
        rotation_deg: Rotation angle in degrees. Positive values rotate
            counter-clockwise.

    Returns:
        JAX array ``[[cos, -sin], [sin, cos]]`` of shape (2, 2).
    """
    angle_rad = jnp.deg2rad(rotation_deg)
    c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
    top_row = [c, -s]
    bottom_row = [s, c]
    return jnp.array([top_row, bottom_row])
[docs]
@functools.partial(jax.jit, static_argnames=["order", "mode"])
def shift_image(
    image: jax.Array,
    shift_y: float,
    shift_x: float,
    order: int = 3,
    mode: str = "constant",
    cval: float = 0.0,
) -> jax.Array:
    """Translate a 2D image by a (possibly fractional) number of pixels.

    The output is built by inverse mapping: output pixel ``(y, x)`` takes
    its value from the input sampled at ``(y - shift_y, x - shift_x)``, so
    the image content moves by ``(+shift_y, +shift_x)``.

    Args:
        image: 2D input image.
        shift_y: Shift along axis 0 in pixels. Positive = Down (toward
            higher row indices).
        shift_x: Shift along axis 1 in pixels. Positive = Right (toward
            higher column indices).
        order: Interpolation order forwarded to ``map_coordinates``
            (static under JIT). Default is 3, which uses the Keys cubic
            convolution kernel (see ``docs/interpolation.md``).
        mode: Boundary handling mode forwarded to ``map_coordinates``
            (static under JIT).
        cval: Fill value used outside the image in 'constant' mode.

    Returns:
        Shifted image with the same shape as ``image``.
    """
    n_rows, n_cols = image.shape
    # Integer (row, col) index of every output pixel.
    row_idx, col_idx = jnp.indices((n_rows, n_cols))
    # Where in the input each output pixel should be sampled from.
    sample_points = [row_idx - shift_y, col_idx - shift_x]
    return map_coordinates(
        image, sample_points, order=order, mode=mode, cval=cval
    )
[docs]
@functools.partial(jax.jit, static_argnames=["shape_tgt", "order"])
def resample_flux(
    f_src: jax.Array,
    pixscale_src: float,
    pixscale_tgt: float,
    shape_tgt: tuple[int, int],
    rotation_deg: float = 0.0,
    order: int = 3,
) -> jax.Array:
    """Resample an image onto a new pixel grid, preserving total flux.

    The source image (integrated flux per pixel) is converted to surface
    brightness and sampled at the centres of the target pixels through
    an affine map (rotation plus isotropic scaling about the grid
    centres). The sampled values are then multiplied by the target pixel
    area to give integrated flux per target pixel.

    Args:
        f_src: Source image (2D) with integrated flux per pixel.
        pixscale_src: Pixel scale of the source image.
        pixscale_tgt: Pixel scale of the target image (same units as src).
        shape_tgt: Target shape (ny_tgt, nx_tgt). Static under JIT, so it
            must be hashable (a tuple).
        rotation_deg: CCW rotation angle in degrees.
        order: Interpolation order passed to ``map_coordinates``. Default
            is 3, which uses the Keys cubic convolution kernel -- a true
            interpolant with partition of unity at integer grid spacing
            that conserves flux on integer downsampling of band-limited
            inputs. See ``docs/interpolation.md``.

    Returns:
        Resampled image of shape (ny_tgt, nx_tgt). Total flux is conserved
        to the extent that point-sampling the interpolated surface
        brightness approximates its area integral. Target pixels that map
        outside the source receive zero ('constant' mode, cval=0), and
        source flux outside the target footprint is not included.
    """
    src_rows, src_cols = f_src.shape
    tgt_rows, tgt_cols = shape_tgt

    # Work in surface brightness (flux per unit area) so interpolated
    # values do not depend on the pixel size.
    sb_src = f_src / (pixscale_src**2)

    # Linear part of the TARGET -> SOURCE map: rotate, then stretch by the
    # number of source pixels spanned by one target pixel.
    ratio = pixscale_tgt / pixscale_src
    lin = ccw_rotation_matrix(rotation_deg) * ratio

    # Translation that sends the target grid centre onto the source grid
    # centre.
    centre_src = jnp.array([(src_rows - 1) / 2.0, (src_cols - 1) / 2.0])
    centre_tgt = jnp.array([(tgt_rows - 1) / 2.0, (tgt_cols - 1) / 2.0])
    shift = centre_src - lin @ centre_tgt

    # (row, col) indices of every target pixel centre, stacked to
    # (2, tgt_rows, tgt_cols) and flattened so a single matmul maps them.
    tgt_idx = jnp.stack(
        jnp.meshgrid(jnp.arange(tgt_rows), jnp.arange(tgt_cols), indexing="ij")
    )
    flat_src = lin @ tgt_idx.reshape(2, -1) + shift[:, None]
    src_pts = flat_src.reshape(tgt_idx.shape)

    sb_tgt = map_coordinates(
        sb_src, [src_pts[0], src_pts[1]], order=order, mode="constant", cval=0.0
    )

    # Convert back to integrated flux in each target pixel.
    return sb_tgt * (pixscale_tgt**2)
# ---------------------------------------------------------------------------
# PSF Downsampling
# ---------------------------------------------------------------------------
[docs]
def downsample_psf(
    psf: jax.Array,
    src_pixscale: float,
    target_shape: tuple[int, int],
) -> tuple[jax.Array, float]:
    """Downsample a PSF to target shape while conserving total flux.

    The target pixel scale is chosen so that the target grid spans the same
    extent as the source grid along axis 0. The resampling itself is done
    by ``resample_flux`` with no rotation.

    Args:
        psf: The source PSF image (2D array).
        src_pixscale: The pixel scale of the source PSF (in lambda/D or
            other consistent units).
        target_shape: The target shape (ny_tgt, nx_tgt). Any length-2
            sequence of positive integers is accepted (tuple, list, NumPy
            array). It is normalised to a tuple of ints because
            ``resample_flux`` treats it as a static (hashable) JIT argument.

    Returns:
        Tuple of (resampled_psf, new_pixscale).

    Raises:
        ValueError: If ``psf`` is not 2D, or ``target_shape`` is not two
            positive integers.
    """
    if psf.ndim != 2:
        raise ValueError(f"psf must be a 2D array, got shape {psf.shape}")

    # Lists / arrays are unhashable and would be rejected as a static jit
    # argument of resample_flux, so normalise to a plain tuple of ints.
    shape_tgt = tuple(int(n) for n in target_shape)
    if len(shape_tgt) != 2 or min(shape_tgt) <= 0:
        raise ValueError(
            f"target_shape must be two positive integers, got {target_shape!r}"
        )

    # Only axis 0 sets the scale factor. A single isotropic pixel scale is
    # used for both axes, so if the source and target aspect ratios differ
    # the field of view along axis 1 is not preserved.
    scale_factor = psf.shape[0] / shape_tgt[0]
    tgt_pixscale = src_pixscale * scale_factor

    resampled = resample_flux(
        psf,
        src_pixscale,
        tgt_pixscale,
        shape_tgt,
        rotation_deg=0.0,
    )
    return resampled, tgt_pixscale
[docs]
def downsample_psfs(
    psfs: jax.Array,
    src_pixscale: float,
    target_shape: tuple[int, int],
) -> tuple[jax.Array, float]:
    """Downsample a stack of PSFs to target shape while conserving total flux.

    Each PSF is resampled independently (via ``jax.vmap``) with
    ``resample_flux`` and no rotation. All PSFs share one target pixel
    scale, chosen so the target grid spans the same extent as the source
    grid along the image row axis (H).

    Args:
        psfs: Stack of PSF images with shape (N, H, W).
        src_pixscale: The pixel scale of the source PSFs.
        target_shape: The target shape (ny_tgt, nx_tgt) for each PSF. Any
            length-2 sequence of positive integers is accepted (tuple,
            list, NumPy array). It is normalised to a tuple of ints because
            ``resample_flux`` treats it as a static (hashable) JIT argument.

    Returns:
        Tuple of (resampled_psfs, new_pixscale), where resampled_psfs has
        shape (N, ny_tgt, nx_tgt).

    Raises:
        ValueError: If ``psfs`` is not 3D, or ``target_shape`` is not two
            positive integers.
    """
    if psfs.ndim != 3:
        raise ValueError(f"psfs must have shape (N, H, W), got {psfs.shape}")

    # Lists / arrays are unhashable and would be rejected as a static jit
    # argument of resample_flux, so normalise to a plain tuple of ints.
    shape_tgt = tuple(int(n) for n in target_shape)
    if len(shape_tgt) != 2 or min(shape_tgt) <= 0:
        raise ValueError(
            f"target_shape must be two positive integers, got {target_shape!r}"
        )

    # Only the row axis (H) sets the scale factor. One isotropic pixel
    # scale is used for both axes.
    scale_factor = psfs.shape[1] / shape_tgt[0]
    tgt_pixscale = src_pixscale * scale_factor

    def _resample_one(psf):
        # Resample a single (H, W) PSF. The pixel scales and shape are
        # closed over, so they are constant across the vmapped batch.
        return resample_flux(
            psf,
            src_pixscale,
            tgt_pixscale,
            shape_tgt,
            rotation_deg=0.0,
        )

    return jax.vmap(_resample_one)(psfs), tgt_pixscale