# Source code for hwoutils.fft

"""FFT-based sub-pixel image shifting.

Provides Fourier shift primitives for sub-pixel image translation. The JAX
versions (fft_shift_x, fft_shift_y) accept precomputed phasors for efficient
repeated shifts. The NumPy versions (fft_shift, fft_shift_1d) are standalone.

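All functions operate on 2D images and shift one axis at a time, using 1D FFTs
along that axis only (one transform per row or column). A single-axis shift
therefore costs O(N^2 log N) and skips the transform along the other axis that
a full 2D FFT would also perform.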
"""

import jax.numpy as jnp
import numpy as np
from jax import lax

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def get_pad_info(image, pad_factor):
    """Compute padding sizes for FFT shift operations.

    The image is zero-padded by ``int(pad_factor * N)`` pixels on each side,
    where ``N = image.shape[0]``. For example, ``pad_factor=1.5`` adds 150% of
    the image size on each side, which gives a padded length of ``4 * N`` when
    ``N`` is even.

    Args:
        image: 2D input image (JAX or NumPy array). Only ``image.shape[0]`` is
            read, so a square image is assumed.
        pad_factor: Padding added to each side, as a multiple of the image size.

    Returns:
        Tuple ``(n_pixels_orig, n_pad, img_edge, n_pixels_final)``:

        - ``n_pixels_orig``: the original image size ``N``.
        - ``n_pad``: the per-side padding.
        - ``img_edge``: equal to ``n_pad + n_pixels_orig``; this is the
          exclusive end index of the original image inside the padded array.
        - ``n_pixels_final``: equal to ``2 * n_pad + n_pixels_orig``; this is
          the padded length.
    """
    n_pixels_orig = image.shape[0]
    n_pad = int(pad_factor * n_pixels_orig)
    img_edge = n_pad + n_pixels_orig
    # Derive the padded length from the truncated n_pad so it always matches the
    # array produced by padding n_pad on both sides. Computing it as
    # int(2 * N * pad_factor + N) overshoots when pad_factor * N is not an
    # integer (e.g. N=3, pad_factor=1.5 would give 12 instead of 11).
    n_pixels_final = 2 * n_pad + n_pixels_orig
    return n_pixels_orig, n_pad, img_edge, n_pixels_final
# ---------------------------------------------------------------------------
# JAX versions (JIT-compatible, require precomputed phasors)
# ---------------------------------------------------------------------------
def fft_shift_x(image, shift_pixels, phasor, clamp=True):
    """Translate ``image`` by a (sub-)pixel amount along x using JAX.

    Intended for JIT use: the caller builds ``phasor`` once and reuses it for
    every shift of same-shaped images, so each call only evaluates
    ``phasor**shift_pixels``.

    Args:
        image: 2D JAX array. Padding and cropping are both sized from
            ``image.shape[0]``, so a square image is assumed.
        shift_pixels: Shift along x (axis 1) in pixels; may be fractional.
        phasor: Precomputed ``exp(-2j * pi * fft_freqs)`` with one entry per
            column of the zero-padded image.
        clamp: When True, negative values in the result are set to 0.

    Returns:
        The shifted image, cropped back to the same shape as ``image``.
    """
    _, n_pad, img_edge, _ = get_pad_info(image, 1.5)
    fill = jnp.zeros((), dtype=image.dtype)
    # A zero border on every side keeps the circular FFT shift from wrapping
    # content around the image edge.
    border = (n_pad, n_pad, 0)
    framed = lax.pad(image, fill, [border, border])

    # One phase factor per column frequency, broadcast down every row.
    ramp = phasor**shift_pixels
    spectrum = jnp.fft.fft(framed, axis=1) * ramp
    restored = jnp.real(jnp.fft.ifft(spectrum, axis=1))

    cropped = restored[n_pad:img_edge, n_pad:img_edge]
    return jnp.maximum(cropped, 0.0) if clamp else cropped
def fft_shift_y(image, shift_pixels, phasor, clamp=True):
    """Translate ``image`` by a (sub-)pixel amount along y using JAX.

    Intended for JIT use: the caller builds ``phasor`` once and reuses it for
    every shift of same-shaped images, so each call only evaluates
    ``phasor**shift_pixels``.

    Args:
        image: 2D JAX array. Padding and cropping are both sized from
            ``image.shape[0]``, so a square image is assumed.
        shift_pixels: Shift along y (axis 0) in pixels; may be fractional.
        phasor: Precomputed ``exp(-2j * pi * fft_freqs)`` with one entry per
            row of the zero-padded image.
        clamp: When True, negative values in the result are set to 0.

    Returns:
        The shifted image, cropped back to the same shape as ``image``.
    """
    _, n_pad, img_edge, _ = get_pad_info(image, 1.5)
    fill = jnp.zeros((), dtype=image.dtype)
    # A zero border on every side keeps the circular FFT shift from wrapping
    # content around the image edge.
    border = (n_pad, n_pad, 0)
    framed = lax.pad(image, fill, [border, border])

    # One phase factor per row frequency, shaped as a column so it
    # broadcasts across every column.
    ramp = jnp.reshape(phasor**shift_pixels, (-1, 1))
    spectrum = jnp.fft.fft(framed, axis=0) * ramp
    restored = jnp.real(jnp.fft.ifft(spectrum, axis=0))

    cropped = restored[n_pad:img_edge, n_pad:img_edge]
    return jnp.maximum(cropped, 0.0) if clamp else cropped
# ---------------------------------------------------------------------------
# NumPy versions (standalone, no precomputed phasors needed)
# ---------------------------------------------------------------------------
def fft_shift_1d(image, shift_pixels, axis):
    """Apply a Fourier shift along a specified axis (NumPy).

    Steps:

    1. Zero-pad the image along ``axis`` only.
    2. Multiply its 1D FFT along that axis by the linear phase ramp
       ``exp(-2j * pi * freqs * shift_pixels)``.
    3. Inverse transform and crop back to the original extent.

    This standalone version computes its own phasor internally.

    Args:
        image: 2D input image (NumPy array or array-like). Non-square images
            are supported.
        shift_pixels: Sub-pixel shift amount.
        axis: Axis to shift (0 for vertical/y, 1 for horizontal/x). Negative
            values count back from the last axis.

    Returns:
        Real-valued shifted image with the same shape as the input.

    Raises:
        ValueError: If ``axis`` is out of range for ``image``.
    """
    image = np.asarray(image)
    ndim = image.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    axis %= ndim  # normalize negative axes so the bookkeeping below is simple

    n_pixels = image.shape[axis]
    # Pad by 1.5x the axis length on each side so the circular FFT shift cannot
    # wrap content around the edge. The other axis needs no padding: the 1D
    # transforms along `axis` are independent of each other.
    n_pad = int(1.5 * n_pixels)
    pad_width = [(0, 0)] * ndim
    pad_width[axis] = (n_pad, n_pad)
    padded = np.pad(image, pad_width, mode="constant")

    # The frequency grid must match the padded length (2 * n_pad + n_pixels);
    # 4 * n_pixels only equals that when n_pixels is even.
    n_fft = padded.shape[axis]
    phasor = np.exp(-2j * np.pi * np.fft.fftfreq(n_fft) * shift_pixels)
    ramp_shape = [1] * ndim
    ramp_shape[axis] = n_fft
    spectrum = np.fft.fft(padded, axis=axis) * phasor.reshape(ramp_shape)
    shifted = np.real(np.fft.ifft(spectrum, axis=axis))

    crop = [slice(None)] * ndim
    crop[axis] = slice(n_pad, n_pad + n_pixels)
    return shifted[tuple(crop)]
def fft_shift(image, x=0, y=0):
    """Apply Fourier shifts along x and/or y axes (NumPy).

    Convenience wrapper that calls ``fft_shift_1d`` for each non-zero axis.
    The x shift (axis 1) is applied first, then the y shift (axis 0).

    Args:
        image: 2D input image (NumPy array).
        x: Sub-pixel shift along x-axis.
        y: Sub-pixel shift along y-axis.

    Returns:
        Shifted image with same shape as input.

    Raises:
        AssertionError: If both x and y are zero.
    """
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`; AssertionError is kept because it is the documented type.
    if x == 0 and y == 0:
        raise AssertionError("One of x or y must be non-zero.")
    if x != 0:
        image = fft_shift_1d(image, x, axis=1)
    if y != 0:
        image = fft_shift_1d(image, y, axis=0)
    return image