Source code for rheedium.tools.wrappers
"""Function wrappers for JAX compatibility.

Extended Summary
----------------
Provides decorator-style wrappers that bridge compatibility gaps
between JAX and external tools. The primary use case is ensuring
that functions decorated with ``beartype`` + ``jaxtyping`` accept
inputs from tools that produce numpy arrays instead of JAX arrays
(e.g. ``jax.test_util.check_grads``).

Routine Listings
----------------
:func:`jax_safe`
    Wrap a function to convert all positional arguments to JAX
    arrays before calling.

Notes
-----
These wrappers are intentionally minimal and composable. They do
not modify the return value or keyword arguments of the wrapped
function.
"""
from collections.abc import Callable
from functools import wraps
from typing import Any

import jax.numpy as jnp
from beartype import beartype
from jaxtyping import jaxtyped
[docs]
@jaxtyped(typechecker=beartype)
def jax_safe(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a function to convert positional args to JAX arrays.

    Parameters
    ----------
    fn : Callable[..., Any]
        Function whose positional arguments should be converted
        via ``jnp.asarray`` before dispatch.

    Returns
    -------
    wrapper : Callable[..., Any]
        Wrapped function that calls ``jnp.asarray`` on each
        positional argument and forwards keyword arguments
        unchanged. It carries ``fn``'s metadata (``__name__``,
        ``__qualname__``, ``__doc__``, ``__wrapped__``) via
        :func:`functools.wraps`.

    Notes
    -----
    1. Iterate over all positional arguments passed to ``fn``.
    2. Call ``jnp.asarray`` on each, converting numpy scalars
       and arrays to their JAX equivalents.
    3. Forward the converted positional arguments, plus any
       keyword arguments untouched, to ``fn`` and return the
       result unchanged.

    This is required when using ``jax.test_util.check_grads``,
    which perturbs inputs via numpy arithmetic. The perturbed
    values are numpy scalars (``f64[](numpy)``) that fail
    beartype's ``Float[Array, '']`` checks. Wrapping the
    function under test with ``jax_safe`` resolves this.

    Every positional argument is converted, not only scalars.
    A positional argument that ``jnp.asarray`` cannot convert
    makes the call raise. Pass such values as keyword arguments
    instead, because keyword arguments are not converted.
    """

    # ``wraps`` keeps the original name/docstring visible in tracebacks,
    # test reports and ``inspect`` output instead of a generic "wrapper".
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Only positional inputs are coerced; keyword arguments (flags,
        # static configuration, ...) pass through exactly as given.
        converted = tuple(jnp.asarray(a) for a in args)
        return fn(*converted, **kwargs)

    return wrapper
# Public API of this module: only the JAX-compatibility wrapper is exported.
__all__: list[str] = ["jax_safe"]