hwoutils
========

.. py:module:: hwoutils

.. autoapi-nested-parse::

   hwoutils -- Shared JAX-based utilities for the HWO simulation suite.



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/hwoutils/constants/index
   /autoapi/hwoutils/conversions/index
   /autoapi/hwoutils/fft/index
   /autoapi/hwoutils/jax_config/index
   /autoapi/hwoutils/map_coordinates/index
   /autoapi/hwoutils/radial/index
   /autoapi/hwoutils/snapshot/index
   /autoapi/hwoutils/transforms/index


Attributes
----------

.. autoapisummary::

   hwoutils.__version__


Functions
---------

.. autoapisummary::

   hwoutils.enable_x64
   hwoutils.set_host_device_count
   hwoutils.set_platform


Package Contents
----------------

.. py:data:: __version__
   :value: 'unknown'


.. py:function:: enable_x64(use_x64 = True)

   Enable 64-bit floating-point precision in JAX.

   By default JAX uses 32-bit precision. Call this before any array
   operations to switch to 64-bit (NumPy-compatible) precision.

   Args:
       use_x64: If True, enable 64-bit precision. If False, the
           ``JAX_ENABLE_X64`` environment variable decides the
           precision instead.

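   The behaviour described above can be sketched as follows. This assumes the switch goes through JAX's own ``jax_enable_x64`` configuration option. ``enable_x64_sketch`` is a hypothetical name, not the hwoutils source, and the set of values treated as "on" for ``JAX_ENABLE_X64`` is also an assumption.

   .. code-block:: python

      import os

      import jax
      import jax.numpy as jnp

      # Assumed spellings that count as "enabled" in JAX_ENABLE_X64.
      _TRUTHY = {"1", "true", "yes", "on"}


      def enable_x64_sketch(use_x64: bool = True) -> bool:
          """Hypothetical stand-in for hwoutils.enable_x64 (not the real code)."""
          if not use_x64:
              # Defer to the environment variable; unset or unrecognised means off.
              env = os.environ.get("JAX_ENABLE_X64", "")
              use_x64 = env.strip().lower() in _TRUTHY
          # jax_enable_x64 is JAX's switch for 64-bit default dtypes.
          jax.config.update("jax_enable_x64", bool(use_x64))
          return bool(use_x64)


      enable_x64_sketch()
      print(jnp.arange(3.0).dtype)  # float64

   JAX chooses the default dtype when an array is created. Arrays built before the call therefore keep their 32-bit dtype, which is why the call belongs at the top of a script.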

.. py:function:: set_host_device_count(n)

   Present the host CPU to XLA as *n* separate devices.

   By default XLA treats all CPU cores as one device. This function
   sets the ``XLA_FLAGS`` environment variable so that
   :func:`jax.pmap` can distribute work across *n* host devices.

   Must be called before any JAX computation.

   Args:
       n: Number of CPU devices to expose.

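   A minimal sketch of this, assuming the device count is passed through XLA's ``--xla_force_host_platform_device_count`` flag. The docstring above only says that ``XLA_FLAGS`` is set. ``set_host_device_count_sketch`` is a hypothetical name, and rejecting ``n < 1`` is an added assumption. Any other flags already present in ``XLA_FLAGS`` are preserved.

   .. code-block:: python

      import os

      _FLAG = "--xla_force_host_platform_device_count"


      def set_host_device_count_sketch(n: int) -> str:
          """Hypothetical stand-in for hwoutils.set_host_device_count."""
          if n < 1:
              raise ValueError(f"n must be a positive integer, got {n}")
          # Keep unrelated XLA flags but drop any earlier device-count setting.
          current = os.environ.get("XLA_FLAGS", "").split()
          others = [flag for flag in current if not flag.startswith(_FLAG + "=")]
          flags = " ".join([f"{_FLAG}={n}", *others])
          os.environ["XLA_FLAGS"] = flags
          return flags


      set_host_device_count_sketch(4)
      # If this runs before JAX's first computation, len(jax.devices("cpu"))
      # should then report 4.

   The sketch rewrites the variable instead of appending to it. This keeps repeated calls from accumulating conflicting device-count flags.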

.. py:function:: set_platform(platform = None)

   Select the JAX compute platform (cpu, gpu, or tpu).

   Must be called before any JAX computation.

   Args:
       platform: One of ``'cpu'``, ``'gpu'``, or ``'tpu'``. If None,
           the ``JAX_PLATFORMS`` environment variable is used, or
           ``'cpu'`` if that variable is unset.

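   The defaulting rules above can be sketched as follows. This assumes the choice is applied through JAX's ``jax_platforms`` configuration option, and hwoutils may use a different mechanism. ``resolve_platform`` and ``set_platform_sketch`` are hypothetical names.

   .. code-block:: python

      import os

      _DOCUMENTED = ("cpu", "gpu", "tpu")


      def resolve_platform(platform=None):
          """Apply the documented defaulting rules (hypothetical helper)."""
          if platform is None:
              # JAX_PLATFORMS is passed through unchanged: JAX itself also accepts
              # comma-separated lists there, so it is not validated here.
              return os.environ.get("JAX_PLATFORMS") or "cpu"
          platform = platform.lower()
          if platform not in _DOCUMENTED:
              raise ValueError(f"platform must be one of {_DOCUMENTED}, got {platform!r}")
          return platform


      def set_platform_sketch(platform=None):
          """Hypothetical stand-in for hwoutils.set_platform."""
          import jax  # deferred so resolve_platform() is usable without JAX

          chosen = resolve_platform(platform)
          # Restrict which backends JAX initialises; this only takes effect
          # if it runs before the first JAX computation.
          jax.config.update("jax_platforms", chosen)
          return chosen

   The sketch keeps the pure resolution step separate from the side-effecting ``jax.config`` call. This lets the defaulting logic be checked without touching JAX's global state.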

