# JAX Configuration Guide
`hwoutils` provides three functions for configuring the JAX runtime. All three are re-exported from the top-level `hwoutils` namespace:

```python
from hwoutils import enable_x64, set_platform, set_host_device_count
```
## The Golden Rule: Configure Before Computing

JAX initializes its backends on the first JAX operation, which can happen as early as the first import that triggers array creation. Configuration set after initialization is either silently ignored or raises a `RuntimeError`.
```python
# CORRECT -- configure immediately after import
import jax
from hwoutils import enable_x64, set_platform

enable_x64()
set_platform("cpu")

import jax.numpy as jnp  # safe: config is already set

x = jnp.ones(10)  # uses float64 on CPU
```

```python
# WRONG -- too late, JAX already initialized
import jax.numpy as jnp

x = jnp.ones(10)  # triggers backend init with float32

from hwoutils import enable_x64

enable_x64()  # may be silently ignored!
```
> [!CAUTION]
> Library imports can trigger JAX initialization. If your project imports a JAX-based library (`orbix`, `coronagraphoto`, etc.) before calling `enable_x64()`, the flag may not take effect. Always configure JAX at the very top of your entry point.
### Environment variables (safest approach)

Environment variables are set before the Python process starts. They therefore take effect regardless of which module imports JAX first:

```bash
JAX_ENABLE_X64=True JAX_PLATFORMS=cpu python my_script.py
```

Or in a shell profile / CI config:

```bash
export JAX_ENABLE_X64=True
export JAX_PLATFORMS=cpu
```
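Sometimes you cannot change how a script is launched. In that case, a similar effect can be had by setting the same variables through `os.environ` at the very top of the entry point, before anything imports JAX. This is a minimal sketch using only `os` and plain JAX; it is not an `hwoutils` API:

```python
import os

# Must run before any module imports jax: JAX reads these when it is first imported.
os.environ["JAX_ENABLE_X64"] = "True"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (deliberately imported after the environment is set)
import jax.numpy as jnp  # noqa: E402

x = jnp.ones(3)
print(x.dtype)                # float64
print(jax.default_backend())  # cpu
```

If a library imported earlier in the same process has already pulled in JAX, these assignments come too late. The shell-level variables avoid that risk.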
## `enable_x64`

Switches JAX from its default 32-bit types to 64-bit precision (`float64`, and `int64` for integers).

```python
from hwoutils import enable_x64

enable_x64()       # enable float64
enable_x64(False)  # revert to float32 (or read JAX_ENABLE_X64 env var)
```
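For reference, the plain-JAX equivalent is a single `jax.config.update` call. The sketch below assumes, without having checked the `hwoutils` source, that `enable_x64()` sets this same `jax_enable_x64` key. It also shows that the switch changes the default integer type, not just floats:

```python
import jax

# Plain-JAX equivalent of enable_x64(). Assumption: the hwoutils helper sets
# this same config key. Run it before the first JAX operation.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp  # noqa: E402

x = jnp.ones(3)
i = jnp.arange(3)
print(x.dtype, i.dtype)  # float64 int64

# 1e-10 is far below float32 resolution near 1.0 (about 1.2e-7) but survives in float64.
print(bool(jnp.array(1.0) + 1e-10 == 1.0))  # False
```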
### The `jax.enable_x64()` context manager (JAX >= 0.8.0)

As of JAX 0.8.0 (Oct 2025), there is a built-in context manager for temporarily enabling 64-bit precision:

```python
import jax
import jax.numpy as jnp

# Default 32-bit
x = jnp.ones(3)  # float32

with jax.enable_x64():
    y = jnp.ones(3)  # float64 inside the block

z = jnp.ones(3)  # back to float32
```
This is useful for tests or isolated calculations that need double precision without affecting the rest of the program.
> [!NOTE]
> The context manager replaces the deprecated `jax.experimental.enable_x64()`. If you see the experimental version in old code, update it to `jax.enable_x64()`.
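### When to use which

| Scenario | Approach |
|---|---|
| Entire program needs float64 | `enable_x64()` at the top of the entry point |
| One test or function needs float64 | `with jax.enable_x64():` |
| CI / batch jobs | `JAX_ENABLE_X64=True` environment variable |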
## `set_platform`

Selects the compute backend (CPU, GPU, or TPU).

```python
from hwoutils import set_platform

set_platform("cpu")  # force CPU even if a GPU is available
set_platform("gpu")  # use GPU
set_platform()       # read JAX_PLATFORMS env var, default "cpu"
```

> [!IMPORTANT]
> This uses the modern `jax_platforms` config key. Older code using `jax_platform_name` (singular) triggers a deprecation warning in recent JAX versions.
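The documented resolution order is: explicit argument, then the `JAX_PLATFORMS` environment variable, then `"cpu"`. It can be sketched in plain JAX as below. `set_platform_sketch` is a hypothetical name, and this is an illustration of the documented behavior, not the actual `hwoutils` implementation:

```python
import os
from typing import Optional

import jax


def set_platform_sketch(platform: Optional[str] = None) -> str:
    """Hypothetical stand-in for hwoutils.set_platform, NOT its actual source.

    Mirrors the documented behavior: an explicit argument wins; otherwise the
    JAX_PLATFORMS environment variable is used, falling back to "cpu".
    """
    if platform is None:
        platform = os.environ.get("JAX_PLATFORMS") or "cpu"
    # Modern plural key; the singular jax_platform_name is deprecated.
    jax.config.update("jax_platforms", platform)
    return platform


# Call before any JAX computation initializes a backend.
chosen = set_platform_sketch("cpu")
print(chosen, jax.default_backend())  # cpu cpu
```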
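### Common gotcha: “No GPU/TPU found”

If you call `set_platform("gpu")` but no GPU is available, JAX does not fall back to CPU. Instead, it raises an error when it initializes the backend, usually at the first computation.

Listing several platforms does not add a fallback either. JAX initializes every platform named in `JAX_PLATFORMS` / `jax_platforms` and raises if any of them fails; the first entry becomes the default. `gpu,cpu` therefore keeps a CPU backend available alongside the GPU, but it still fails on a machine without one:

```bash
# GPU is the default; CPU is also initialized. Still errors if no GPU is present.
export JAX_PLATFORMS=gpu,cpu
```

Or in code:

```python
set_platform("gpu,cpu")
```

For automatic fallback, leave `JAX_PLATFORMS` unset and skip `set_platform`. By default, JAX tries the available accelerators and falls back to CPU.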
## `set_host_device_count`

Exposes the host CPU as multiple separate XLA devices, enabling `jax.pmap` on CPU:

```python
from hwoutils import set_host_device_count

set_host_device_count(4)  # expose 4 CPU devices
```

> [!WARNING]
> This must be called before any JAX operation. It works by setting the `XLA_FLAGS` environment variable, which XLA reads only once, during initialization.
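The sketch below shows the mechanism the warning describes, using only `os`, `re`, and plain JAX. Two things in it are assumptions. The function name `set_host_device_count_sketch` is hypothetical. The specific flag (`--xla_force_host_platform_device_count`, the standard XLA flag for this) is an assumption about what `hwoutils` writes into `XLA_FLAGS`.

```python
import os
import re


def set_host_device_count_sketch(n: int) -> None:
    """Hypothetical stand-in for hwoutils.set_host_device_count (not its source).

    Adds --xla_force_host_platform_device_count=n to XLA_FLAGS, replacing any
    earlier value of that flag while keeping unrelated flags.
    """
    flags = os.environ.get("XLA_FLAGS", "")
    flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", flags)
    os.environ["XLA_FLAGS"] = f"{flags} --xla_force_host_platform_device_count={n}".strip()


set_host_device_count_sketch(4)
os.environ["JAX_PLATFORMS"] = "cpu"

# XLA reads XLA_FLAGS once, when the CPU backend initializes, so import jax
# only after the flag is in place.
import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

print(jax.device_count())  # 4

# pmap splits the leading axis (size 4) across the 4 CPU devices.
xs = jnp.arange(8.0).reshape(4, 2)
row_sums = jax.pmap(jnp.sum)(xs)
print(row_sums.tolist())  # [1.0, 5.0, 9.0, 13.0]
```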
## Quick Reference

```python
# Standard preamble for scripts in this workspace
import jax
from hwoutils import enable_x64, set_platform

enable_x64()
set_platform("cpu")

# Now safe to import JAX-based libraries and do work
import jax.numpy as jnp
from yippy import Coronagraph
from orbix import ...
```
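A late configuration can be silently ignored, so it can help to verify the preamble right after it runs. This is a minimal sketch with two assumptions. Plain `jax.config.update` calls stand in for `enable_x64()` and `set_platform("cpu")`, on the assumption that those helpers set the same keys. `check_jax_config` is a hypothetical helper, not part of `hwoutils`:

```python
import jax

# Stand-ins for enable_x64() and set_platform("cpu"); we assume the hwoutils
# helpers set these same config keys.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")

import jax.numpy as jnp  # noqa: E402


def check_jax_config(expect_x64: bool = True, expect_backend: str = "cpu") -> None:
    """Hypothetical helper: raise if the preamble above did not take effect."""
    # Creating an array initializes the backend, so call this only after the preamble.
    dtype = str(jnp.zeros(()).dtype)
    want_dtype = "float64" if expect_x64 else "float32"
    if dtype != want_dtype:
        raise RuntimeError(
            f"default float dtype is {dtype}, expected {want_dtype}; "
            "was JAX configured after its first operation?"
        )
    backend = jax.default_backend()
    if backend != expect_backend:
        raise RuntimeError(f"default backend is {backend!r}, expected {expect_backend!r}")


check_jax_config()  # passes silently when the preamble worked
```

Running the check once at startup turns the "Arrays are float32 despite `enable_x64()`" symptom into an immediate, explicit error.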
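## Debugging checklist

| Symptom | Likely cause |
|---|---|
| Arrays are float32 despite `enable_x64()` | Called too late (after the first JAX operation) |
| Error about a missing GPU/TPU backend | `set_platform("gpu")` on a machine without a GPU; JAX does not fall back to CPU |
| `jax.devices()` lists a single CPU device despite `set_host_device_count` | Called after JAX initialized; `XLA_FLAGS` is read only once |
| Deprecation warning mentioning `jax_platform_name` | Old code using the deprecated key; switch to `jax_platforms` |
| Deprecation warning for `jax.experimental.enable_x64()` | Replace with `jax.enable_x64()` |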