hwoutils

hwoutils – Shared JAX-based utilities for the HWO simulation suite.


Attributes

__version__

Functions

enable_x64([use_x64])

Enable 64-bit floating-point precision in JAX.

set_host_device_count(n)

Expose n CPU cores as separate XLA devices.

set_platform([platform])

Select the JAX compute platform (cpu, gpu, or tpu).

Package Contents

hwoutils.__version__ = 'unknown'

hwoutils.enable_x64(use_x64=True)

Enable 64-bit floating-point precision in JAX.

By default JAX uses 32-bit precision: arrays are created as float32/int32, even from 64-bit NumPy inputs. Call this function before any array operations to switch to 64-bit precision, matching NumPy's default dtypes.

Args:
use_x64: When True, JAX arrays use 64-bit precision. When False, the setting is read from the JAX_ENABLE_X64 environment variable instead.

Parameters:

use_x64 (bool)

Return type:

None
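A minimal sketch of the fallback rule described above, assuming the function forwards its decision to JAX's public `jax.config.update("jax_enable_x64", ...)` switch. The names `resolve_x64` and `enable_x64_sketch` are hypothetical, as is the set of strings treated as "true" when reading `JAX_ENABLE_X64`; this is an illustration, not the hwoutils source.

```python
import os
from typing import Mapping, Optional

# Assumption: these environment-variable spellings count as "enabled".
_TRUTHY = {"1", "true", "yes", "on"}


def resolve_x64(use_x64: bool = True, environ: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: decide whether 64-bit mode should be turned on.

    An explicit True wins; otherwise JAX_ENABLE_X64 decides
    (missing or unrecognised values mean 32-bit).
    """
    if use_x64:
        return True
    env = os.environ if environ is None else environ
    return env.get("JAX_ENABLE_X64", "0").strip().lower() in _TRUTHY


def enable_x64_sketch(use_x64: bool = True) -> None:
    """Hypothetical wrapper: apply the resolved setting through JAX's config API."""
    import jax  # imported lazily so resolve_x64 works without JAX installed

    jax.config.update("jax_enable_x64", resolve_x64(use_x64))
```

Deferring the JAX import keeps the decision logic testable on its own. After `enable_x64_sketch()` runs, `jax.numpy.zeros(3).dtype` should be `float64` rather than `float32`.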

hwoutils.set_host_device_count(n)

Expose n CPU cores as separate XLA devices.

By default XLA exposes all CPU cores as a single device. This function sets the XLA_FLAGS environment variable so that jax.pmap() can distribute work across n host devices.

Must be called before any JAX computation: XLA reads XLA_FLAGS only once, when its backend is first initialized.

Args:
n: Number of CPU devices to expose.

Parameters:

n (int)

Return type:

None
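A minimal sketch of how such a function could edit `XLA_FLAGS`, assuming it relies on XLA's `--xla_force_host_platform_device_count` flag, the flag commonly used to simulate several CPU devices (the description above does not name it). `set_host_device_count_sketch` is a hypothetical name, and replacing any earlier device-count flag while keeping the other flags is a design assumption.

```python
import os
from typing import MutableMapping, Optional

# Assumption: the XLA flag that controls the number of host (CPU) devices.
_FLAG = "--xla_force_host_platform_device_count"


def set_host_device_count_sketch(
    n: int, environ: Optional[MutableMapping[str, str]] = None
) -> str:
    """Hypothetical re-implementation: ask XLA for n CPU devices.

    Other flags already present in XLA_FLAGS are preserved; a previous
    device-count setting is replaced. Returns the new XLA_FLAGS value.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    env = os.environ if environ is None else environ
    existing = env.get("XLA_FLAGS", "")
    # Drop any earlier device-count flag so it appears exactly once.
    kept = [flag for flag in existing.split() if not flag.startswith(_FLAG + "=")]
    kept.append(f"{_FLAG}={n}")
    env["XLA_FLAGS"] = " ".join(kept)
    return env["XLA_FLAGS"]
```

Call it before `import jax` (or at least before the first JAX computation). On a CPU-only machine, `jax.local_device_count()` should then report `n`, and `jax.pmap` can map over a leading axis of size `n`.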

hwoutils.set_platform(platform=None)

Select the JAX compute platform (cpu, gpu, or tpu).

Must be called before any JAX computation.

Args:
platform: One of 'cpu', 'gpu', or 'tpu'. If None, the value of the JAX_PLATFORMS environment variable is used, or 'cpu' if that variable is unset.

Parameters:

platform (str | None)

Return type:

None
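A minimal sketch of the documented defaulting rule, assuming the chosen name is passed to JAX's `jax_platforms` configuration option (the option the `JAX_PLATFORMS` environment variable controls). `resolve_platform` and `set_platform_sketch` are hypothetical. Validating only explicit arguments, and passing the environment value through unchanged, are assumptions made so that a list such as `cuda,cpu` in `JAX_PLATFORMS` is not rejected.

```python
import os
from typing import Mapping, Optional

_SUPPORTED = ("cpu", "gpu", "tpu")


def resolve_platform(
    platform: Optional[str] = None, environ: Optional[Mapping[str, str]] = None
) -> str:
    """Hypothetical helper: pick a platform name following the documented defaults."""
    if platform is not None:
        # Explicit argument: normalise and check against the documented choices.
        platform = platform.strip().lower()
        if platform not in _SUPPORTED:
            raise ValueError(f"platform must be one of {_SUPPORTED}, got {platform!r}")
        return platform
    env = os.environ if environ is None else environ
    # A missing or empty JAX_PLATFORMS falls back to 'cpu'.
    return env.get("JAX_PLATFORMS", "").strip() or "cpu"


def set_platform_sketch(platform: Optional[str] = None) -> None:
    """Hypothetical wrapper: must run before JAX initializes its backends."""
    import jax  # imported lazily so resolve_platform works without JAX installed

    jax.config.update("jax_platforms", resolve_platform(platform))
```

Keeping resolution separate from the `jax.config.update` call means the defaulting rule can be checked without touching JAX's global state, which cannot be changed once a backend is live.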