# Source code for rheedium.tools.parallel

"""Parallel processing utilities for distributed RHEED simulations.

Extended Summary
----------------
Provides utilities for sharding arrays across multiple devices
for parallel processing and distributed computing in RHEED
simulation workflows. All functions are JAX-compatible and
support automatic differentiation.

Routine Listings
----------------
:func:`shard_array`
    Shard an array across specified axes and devices for
    parallel processing.

Notes
-----
This module is designed for distributed computing scenarios
where large arrays need to be processed across multiple
devices. The sharding utilities work with JAX's device mesh
system and can be used with various JAX transformations
including ``jit``, ``grad``, and ``vmap``.
"""

import jax
from jax.sharding import NamedSharding, PartitionSpec
from jaxtyping import Array, Num


def shard_array(
    input_array: Num[Array, " ..."],
    shard_axes: int | list[int] | tuple[int, ...],
    devices: list[jax.Device] | tuple[jax.Device, ...] | None = None,
) -> Num[Array, " ..."]:
    """Shard an array across specified axes and devices.

    Extended Summary
    ----------------
    Distributes an array across multiple devices for parallel
    processing by creating a one-dimensional device mesh (mesh axis
    name ``"devices"``) and partitioning the array along the requested
    array axis. All other array axes are replicated.

    Parameters
    ----------
    input_array : Array
        The input array to be sharded.
    shard_axes : int | Sequence[int]
        The axis or axes to shard along. A value of ``-1`` is a
        sentinel meaning "do not shard" (it does NOT mean "last
        axis"). Other negative values index from the end as usual.
        Axes outside ``[-ndim, ndim)`` are ignored. Because the mesh
        is one-dimensional, at most one distinct array axis may be
        sharded.
    devices : Sequence[jax.Device], optional
        The devices to shard across. If ``None``, all available
        devices (``jax.devices()``) are used.

    Returns
    -------
    sharded_array : Array
        The array placed on ``devices`` with a ``NamedSharding``.

    Raises
    ------
    ValueError
        If ``devices`` is empty, or if more than one distinct array
        axis is requested for sharding.
    """
    if devices is None:
        devices = jax.devices()
    devices = list(devices)
    if not devices:
        raise ValueError("shard_array requires at least one device.")
    if isinstance(shard_axes, int):
        shard_axes = [shard_axes]

    ndim: int = input_array.ndim
    # -1 is the "skip" sentinel. Remaining in-range axes are normalised
    # to non-negative indices so equivalent spellings (e.g. 0 and -ndim)
    # collapse to one entry. The range check runs before the modulo, so
    # a 0-d array never evaluates ``ax % 0``.
    resolved_axes: set[int] = {
        ax % ndim for ax in shard_axes if ax != -1 and -ndim <= ax < ndim
    }
    if len(resolved_axes) > 1:
        # A single mesh axis may be mapped to at most one array dimension
        # in a PartitionSpec; fail early with a readable message.
        raise ValueError(
            "shard_array uses a one-dimensional device mesh and can shard "
            f"only one array axis; got axes {sorted(resolved_axes)}."
        )

    # Pass ``devices`` explicitly: without it make_mesh builds the mesh
    # from jax.devices(), ignoring the caller's device selection.
    mesh: jax.sharding.Mesh = jax.make_mesh(
        (len(devices),),
        ("devices",),
        devices=devices,
    )
    pspec_list: list[str | None] = [None] * ndim
    for ax in resolved_axes:
        pspec_list[ax] = "devices"
    pspec: PartitionSpec = PartitionSpec(*pspec_list)
    sharding: NamedSharding = NamedSharding(mesh, pspec)
    # NamedSharding carries its mesh, so no ``with mesh:`` context is
    # required for device_put.
    return jax.device_put(input_array, sharding)
# Names exported by ``from rheedium.tools.parallel import *``.
__all__: list[str] = ["shard_array"]