hwoutils.jax_config
===================

.. py:module:: hwoutils.jax_config

.. autoapi-nested-parse::

   JAX runtime configuration helpers.

   Centralizes JAX platform, precision, and device-count settings so that
   every library in the workspace uses the same API and stays current with
   upstream deprecations.



Functions
---------

.. autoapisummary::

   hwoutils.jax_config.enable_x64
   hwoutils.jax_config.set_platform
   hwoutils.jax_config.set_host_device_count


Module Contents
---------------

.. py:function:: enable_x64(use_x64 = True)

   Enable 64-bit floating-point precision in JAX.

   By default, JAX creates 32-bit arrays (for example, ``float32`` and
   ``int32``) where NumPy would produce 64-bit ones. Call this before
   creating any arrays to make JAX default to 64-bit types, matching NumPy.

   Args:
       use_x64: If True, enable 64-bit precision. If False, defer to
           the ``JAX_ENABLE_X64`` environment variable, so 64-bit mode
           is enabled only when that variable requests it.
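
The ``use_x64`` fallback in :func:`enable_x64` can be illustrated with a small pure function. This is a minimal sketch, not the module's implementation: the name ``resolve_x64`` is hypothetical, and the set of true/false spellings accepted for ``JAX_ENABLE_X64`` is an assumption. A real helper would pass the result to ``jax.config.update("jax_enable_x64", ...)``.

```python
import os
from typing import Mapping, Optional

# Assumed spellings; the module's actual parsing may differ.
_TRUE_VALUES = {"1", "true", "t", "yes", "y", "on"}
_FALSE_VALUES = {"0", "false", "f", "no", "n", "off", ""}


def resolve_x64(use_x64: bool = True,
                environ: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: decide whether JAX should run in 64-bit mode.

    An explicit ``use_x64=True`` always wins; ``False`` defers to the
    ``JAX_ENABLE_X64`` environment variable (unset counts as off).
    """
    if use_x64:
        return True
    env = os.environ if environ is None else environ
    raw = env.get("JAX_ENABLE_X64", "").strip().lower()
    if raw in _TRUE_VALUES:
        return True
    if raw in _FALSE_VALUES:
        return False
    raise ValueError(f"unrecognized JAX_ENABLE_X64 value: {raw!r}")


# A real helper would then apply the result, e.g.:
#     jax.config.update("jax_enable_x64", resolve_x64(use_x64))
print(resolve_x64(True, {}))                         # True
print(resolve_x64(False, {"JAX_ENABLE_X64": "1"}))  # True
print(resolve_x64(False, {}))                        # False
```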


.. py:function:: set_platform(platform = None)

   Select the JAX compute platform (cpu, gpu, or tpu).

   Must be called before any JAX computation, because JAX initializes
   its backend the first time it is used.

   Args:
       platform: One of ``'cpu'``, ``'gpu'``, or ``'tpu'``.
           Defaults to the ``JAX_PLATFORMS`` environment variable,
           or ``'cpu'`` if unset.
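
The default-resolution order documented for :func:`set_platform` can be sketched the same way. The helper ``resolve_platform`` is hypothetical; it assumes that an explicit argument is checked against the three documented names, while a value taken from ``JAX_PLATFORMS`` is passed through untouched (JAX also accepts a comma-separated priority list there, such as ``cuda,cpu``). Forwarding the result to ``jax.config.update("jax_platforms", ...)`` is likewise an assumption about how the module applies it.

```python
import os
from typing import Mapping, Optional

_DOCUMENTED_PLATFORMS = ("cpu", "gpu", "tpu")


def resolve_platform(platform: Optional[str] = None,
                     environ: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper applying the documented default order:

    explicit argument -> ``JAX_PLATFORMS`` -> ``"cpu"``.
    """
    if platform is not None:
        if platform not in _DOCUMENTED_PLATFORMS:
            raise ValueError(
                f"platform must be one of {_DOCUMENTED_PLATFORMS}, got {platform!r}"
            )
        return platform
    env = os.environ if environ is None else environ
    # An unset or empty variable falls back to "cpu"; any other value
    # (possibly a list such as "cuda,cpu") is passed through unchanged.
    return env.get("JAX_PLATFORMS") or "cpu"


# Assumed application step (not confirmed by this page):
#     jax.config.update("jax_platforms", resolve_platform(platform))
print(resolve_platform("gpu", {}))                       # gpu
print(resolve_platform(None, {"JAX_PLATFORMS": "tpu"}))  # tpu
print(resolve_platform(None, {}))                        # cpu
```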


.. py:function:: set_host_device_count(n)

   Expose the host CPU as *n* separate XLA devices.

   By default, JAX sees a single CPU device no matter how many cores the
   machine has. This function sets the ``XLA_FLAGS`` environment variable
   so that :func:`jax.pmap` (or a sharded :func:`jax.jit` computation) can
   distribute work across *n* host devices.

   Must be called before any JAX computation, because XLA reads
   ``XLA_FLAGS`` when the backend is initialized and ignores later changes.

   Args:
       n: Number of CPU devices to expose.
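
XLA's ``--xla_force_host_platform_device_count`` option is the flag that splits the host into multiple CPU devices. Assuming :func:`set_host_device_count` works by editing that flag inside ``XLA_FLAGS``, the string manipulation could look like the sketch below. The helper ``with_host_device_count`` is hypothetical; it drops any earlier device-count entry so the flag appears once, and keeps all other flags in their original order.

```python
import os
from typing import Optional

_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count"


def with_host_device_count(n: int, flags: Optional[str] = None) -> str:
    """Hypothetical helper: return an XLA_FLAGS value requesting n CPU devices.

    ``flags`` defaults to the current ``XLA_FLAGS``; any earlier
    device-count entry is removed and the other flags keep their order.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    current = os.environ.get("XLA_FLAGS", "") if flags is None else flags
    kept = [f for f in current.split() if not f.startswith(_DEVICE_COUNT_FLAG + "=")]
    return " ".join(kept + [f"{_DEVICE_COUNT_FLAG}={n}"])


print(with_host_device_count(4, ""))
# --xla_force_host_platform_device_count=4
print(with_host_device_count(
    8, "--xla_dump_to=/tmp/xla --xla_force_host_platform_device_count=2"))
# --xla_dump_to=/tmp/xla --xla_force_host_platform_device_count=8
```

To take effect, the returned string must be assigned to ``os.environ["XLA_FLAGS"]`` before JAX initializes its CPU backend; ``len(jax.devices("cpu"))`` should then report *n*.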


