hwoutils.jax_config
JAX runtime configuration helpers.
Centralizes JAX platform, precision, and device-count settings so that every library in the workspace uses the same API and stays current with upstream deprecations.
Functions
| enable_x64(use_x64=True) | Enable 64-bit floating-point precision in JAX. |
| set_platform(platform=None) | Select the JAX compute platform (cpu, gpu, or tpu). |
| set_host_device_count(n) | Expose n CPU cores as separate XLA devices. |
Module Contents
- hwoutils.jax_config.enable_x64(use_x64=True)
Enable 64-bit floating-point precision in JAX.
By default JAX uses 32-bit precision. Call this before any array operations to switch to 64-bit (NumPy-compatible) precision.
- Args:
- use_x64: When True, JAX arrays use 64-bit precision. When False, falls back to the JAX_ENABLE_X64 environment variable.
- Parameters:
use_x64 (bool)
- Return type:
None
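The fallback described above can be sketched with JAX's public configuration API. This is a minimal illustration, not the module's actual source: the names `resolve_x64` and `enable_x64_sketch` are hypothetical, and the set of environment-variable spellings treated as "enabled" is an assumption.

```python
import os


def resolve_x64(use_x64: bool = True, env=None) -> bool:
    """Return the effective 64-bit setting: the argument, else JAX_ENABLE_X64."""
    env = os.environ if env is None else env
    if use_x64:
        return True
    # Assumption: these spellings count as "enabled"; the real helper may differ.
    value = env.get("JAX_ENABLE_X64", "").strip().lower()
    return value in {"1", "true", "yes", "on"}


def enable_x64_sketch(use_x64: bool = True) -> None:
    """Apply the resolved setting through jax.config (hypothetical stand-in)."""
    import jax  # imported lazily so resolve_x64 stays dependency-free

    jax.config.update("jax_enable_x64", resolve_x64(use_x64))
```

Once 64-bit mode is on, arrays created without an explicit dtype (for example `jax.numpy.ones(3)`) default to `float64` rather than `float32`, which is why the call needs to happen before any arrays exist.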
- hwoutils.jax_config.set_platform(platform=None)
Select the JAX compute platform (cpu, gpu, or tpu).
Must be called before any JAX computation.
- Args:
- platform: One of 'cpu', 'gpu', or 'tpu'. Defaults to the JAX_PLATFORMS environment variable, or 'cpu' if unset.
- Parameters:
platform (str | None)
- Return type:
None
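The default-resolution order for `platform` (explicit argument, then `JAX_PLATFORMS`, then `'cpu'`) might look roughly like the following. The names `resolve_platform` and `set_platform_sketch` are hypothetical, and routing the result through the `jax_platforms` config option is an assumption about how such a helper could be built, not a description of this module's source.

```python
import os


def resolve_platform(platform=None, env=None) -> str:
    """Pick the platform: explicit argument, then JAX_PLATFORMS, then 'cpu'."""
    env = os.environ if env is None else env
    return platform or env.get("JAX_PLATFORMS") or "cpu"


def set_platform_sketch(platform=None) -> None:
    """Apply the resolved platform through jax.config (hypothetical stand-in)."""
    import jax  # imported lazily so resolve_platform stays dependency-free

    # Assumption: the helper writes the 'jax_platforms' config option.
    jax.config.update("jax_platforms", resolve_platform(platform))
```

JAX initializes its backends on first use, so a call like this only has an effect if it runs before any computation.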
- hwoutils.jax_config.set_host_device_count(n)
Expose n CPU cores as separate XLA devices.
By default XLA treats all CPU cores as one device. This function sets the XLA_FLAGS environment variable so that jax.pmap() can distribute work across n host devices.
Must be called before any JAX computation.
- Args:
n: Number of CPU devices to expose.
- Parameters:
n (int)
- Return type:
None
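One plausible way to implement the `XLA_FLAGS` update is sketched below. It assumes the helper relies on XLA's `--xla_force_host_platform_device_count` flag, which the text above does not name, and it preserves any other flags already present. The function names are hypothetical.

```python
import os

# Assumption: this is the XLA flag the helper sets; the docs only mention XLA_FLAGS.
_FLAG = "--xla_force_host_platform_device_count"


def with_host_device_count(xla_flags: str, n: int) -> str:
    """Return xla_flags with the host device count set to n, replacing any old value."""
    kept = [flag for flag in xla_flags.split() if not flag.startswith(_FLAG + "=")]
    return " ".join([f"{_FLAG}={n}"] + kept)


def set_host_device_count_sketch(n: int) -> None:
    """Write the updated flags back to the environment (hypothetical stand-in)."""
    os.environ["XLA_FLAGS"] = with_host_device_count(os.environ.get("XLA_FLAGS", ""), n)
```

Because XLA reads `XLA_FLAGS` when the CPU backend is first initialized, changing the variable after JAX has run a computation has no effect on the device count.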