rheedium.tools

Utility tools for parallel processing and distributed computing in RHEED simulation workflows.

Reusable numerical and workflow tools for rheedium.

Extended Summary

This package provides the shared numerical infrastructure used by the simulation and crystallography modules. It centralises special-function kernels (modified Bessel functions), quadrature helpers, JAX compatibility wrappers, electron-beam utility functions, and distributed-array sharding so that domain modules depend on a single, well-tested toolbox rather than reimplementing low-level numerics.

All functions are JAX-compatible, JIT-safe, and support automatic differentiation unless noted otherwise.

Routine Listings

bessel_k0()

Modified Bessel function of the second kind, order zero.

bessel_k1()

Modified Bessel function of the second kind, order one.

bessel_kv()

Modified Bessel function of the second kind, arbitrary real order.

gauss_hermite_nodes_weights()

Gauss-Hermite quadrature nodes and weights for Gaussian averaging integrals.

incident_wavevector()

Calculate incident electron wavevector from beam parameters.

interaction_constant()

Relativistic electron interaction constant for multislice calculations.

jax_safe()

Wrap a function to convert positional arguments to JAX arrays before dispatch.

shard_array()

Shard an array across devices for parallel processing.

wavelength_ang()

Calculate relativistic electron wavelength in angstroms.

Notes

The Bessel implementations use piecewise polynomial approximations (Abramowitz & Stegun) for bessel_k0() and bessel_k1(), and series / asymptotic expansions for the general-order bessel_kv(). These are needed by the Lobato-van Dyck projected potential, which is expressed analytically in terms of \(K_0\) and \(K_1\).

The electron-beam utilities (wavelength_ang(), incident_wavevector(), interaction_constant()) live here rather than in rheedium.simul to break circular import chains between simulation sub-modules.

rheedium.tools.shard_array(input_array: Num[Array, '...'], shard_axes: int | list[int] | tuple[int, ...], devices: list[Device] | tuple[Device, ...] | None = None) → Num[Array, '...']

Shard an array across specified axes and devices.

Extended Summary

Distributes an array across multiple devices for parallel processing by creating a device mesh and applying appropriate partitioning based on the specified axes.

param input_array:

The input array to be sharded.

type input_array:

Num[Array, '...']

param shard_axes:

The axis or axes to shard along. Use -1 (or a sequence containing -1) to skip sharding along that axis.

type shard_axes:

int | list[int] | tuple[int, ...]

param devices:

The devices to shard across. If None, all available devices are used.

type devices:

list[Device] | tuple[Device, ...] | None, default: None

returns:

sharded_array – The array distributed across the specified devices.

rtype:

Num[Array, '...']
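The mesh-and-partition approach described above can be sketched as follows. This is a minimal sketch, not rheedium's source: `shard_along_axis` and the mesh axis name `"d"` are hypothetical, it handles a single sharded axis on a one-dimensional device mesh, and it reuses the documented `-1` convention to mean "replicate, do not shard". It relies only on `jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`, and `jax.device_put`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_along_axis(x, axis, devices=None):
    """Hypothetical single-axis sharding helper (not rheedium's code)."""
    devices = jax.devices() if devices is None else list(devices)
    # A 1-D device mesh with one named axis, "d".
    mesh = Mesh(np.array(devices), axis_names=("d",))
    if axis == -1:
        # Documented -1 convention: no sharding, so the array is fully replicated.
        spec = PartitionSpec()
    else:
        # Map array dimension `axis` onto mesh axis "d"; replicate all others.
        spec = PartitionSpec(*["d" if i == axis else None for i in range(x.ndim)])
    return jax.device_put(x, NamedSharding(mesh, spec))


# Make the sharded dimension a multiple of the device count.
n_dev = len(jax.devices())
x = jnp.arange(4 * n_dev * 3, dtype=jnp.float32).reshape(4 * n_dev, 3)
y = shard_along_axis(x, axis=0)
```

With more than one device, the sharded dimension generally needs to be divisible by the number of devices, which is why the example sizes axis 0 as a multiple of `n_dev`. On a single CPU device the call still works and simply places the whole array on that device.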

rheedium.tools.gauss_hermite_nodes_weights(n_points: int) → tuple[Float[Array, 'N'], Float[Array, 'N']]

Compute Gauss-Hermite quadrature nodes and weights.

Return type:

tuple[Float[Array, 'N'], Float[Array, 'N']]
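The entry above does not state which Hermite convention is used (physicists' weight \(e^{-x^2}\) or probabilists' \(e^{-x^2/2}\)), so the following sketch makes an assumption: it takes nodes and weights from `numpy.polynomial.hermite.hermgauss`, which follows the physicists' convention, and shows the change of variables a Gaussian average needs. `gaussian_average` is a hypothetical helper, not part of rheedium.

```python
import jax.numpy as jnp
import numpy as np


def gaussian_average(f, mu, sigma, n_points=16):
    """E[f(X)] for X ~ N(mu, sigma^2) with a physicists' Gauss-Hermite rule."""
    # hermgauss gives nodes/weights for integrals of exp(-x^2) * g(x) over R.
    x, w = np.polynomial.hermite.hermgauss(n_points)
    x, w = jnp.asarray(x), jnp.asarray(w)
    # Substitute X = mu + sqrt(2) * sigma * x; dividing by sqrt(pi) turns the
    # weight exp(-x^2) into a normalised probability density.
    return jnp.sum(w * f(mu + jnp.sqrt(2.0) * sigma * x)) / jnp.sqrt(jnp.pi)
```

An n-point rule integrates polynomials up to degree 2n - 1 exactly, so `gaussian_average(lambda v: v**2, 0.0, 2.0)` recovers the variance 4 up to rounding.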

rheedium.tools.incident_wavevector(lam_ang: float | Float[Array, ''], theta_deg: float | Float[Array, ''], phi_deg: float | Float[Array, ''] = 0.0) → Float[Array, '3']

Calculate the incident electron wavevector for RHEED geometry.

Parameters:
  • lam_ang (Union[float, Float[Array, '']]) – Electron wavelength in angstroms.

  • theta_deg (Union[float, Float[Array, '']]) – Grazing angle of incidence in degrees (angle from surface).

  • phi_deg (Union[float, Float[Array, '']], default: 0.0) – Azimuthal angle in degrees (in-plane rotation). phi = 0 (default): beam along the +x axis, giving horizontal streaks. phi = 90: beam along the +y axis, giving vertical streaks.

Returns:

k_in – Incident wavevector [k_x, k_y, k_z] in reciprocal angstroms. The beam propagates in the surface plane at azimuthal angle phi, with a downward z-component determined by the grazing angle theta.

Return type:

Float[Array, '3']

Notes

  1. Compute wavevector magnitude – \(k = 2\pi / \lambda\).

  2. Convert angles – Convert grazing and azimuthal angles from degrees to radians.

  3. Decompose into components – Split \(k\) into in-plane (\(k_x\), \(k_y\)) and surface-normal (\(k_z\)) components using trigonometric projection.
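The three steps above can be sketched as follows. This is a hypothetical re-derivation, not rheedium's source. It assumes the projection \(k_x = k\cos\theta\cos\phi\), \(k_y = k\cos\theta\sin\phi\), \(k_z = -k\sin\theta\), which is consistent with the description of an in-plane beam at azimuth phi whose downward component is set by the grazing angle theta.

```python
import jax.numpy as jnp


def incident_wavevector_sketch(lam_ang, theta_deg, phi_deg=0.0):
    """Hypothetical k_in in 1/Angstrom; the sign convention is an assumption."""
    # Step 1: magnitude k = 2*pi / lambda.
    k = 2.0 * jnp.pi / lam_ang
    # Step 2: degrees -> radians.
    theta = jnp.deg2rad(theta_deg)
    phi = jnp.deg2rad(phi_deg)
    # Step 3: the in-plane part k*cos(theta) is split by the azimuth phi;
    # the surface-normal part points down into the surface (negative z).
    k_par = k * jnp.cos(theta)
    return jnp.stack([k_par * jnp.cos(phi), k_par * jnp.sin(phi), -k * jnp.sin(theta)])
```

By construction the result has norm \(2\pi/\lambda\), and at phi = 0 the y component vanishes, matching the "beam along +x" description.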

rheedium.tools.interaction_constant(voltage_kv: float | Float[Array, ''], wavelength_ang: float | Float[Array, '']) → Float[Array, '']

Relativistic electron interaction constant σ in 1/(V·Å).

Extended Summary

Computes the relativistic interaction constant \(\sigma\) used in multislice calculations, including the relativistic mass correction via the Lorentz factor \(\gamma\).

Notes

  1. Convert units – Convert voltage from kV to V and wavelength from Ångstroms to metres.

  2. Compute Lorentz factor – Calculate the relativistic \(\gamma\) from the accelerating voltage.

  3. Evaluate interaction constant – \(\sigma = (2\pi m_e e \lambda / h^2) \times \gamma\) in SI units, then convert to \(1/(\mathrm{V} \cdot \text{Å})\).

param voltage_kv:

Accelerating voltage in kilovolts.

type voltage_kv:

Union[float, Float[Array, '']]

param wavelength_ang:

Relativistic electron wavelength in angstroms.

type wavelength_ang:

Union[float, Float[Array, '']]

returns:

sigma – Interaction constant σ (1 / (Volt · Ångstrom)).

rtype:

Float[Array, '']
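The three Notes steps can be sketched as below. `interaction_constant_sketch` is a hypothetical helper, not rheedium's implementation. One design point is worth making explicit: products such as \(h^2 \approx 4.4 \times 10^{-67}\) underflow to zero in float32, JAX's default dtype, so this sketch folds all SI constants into order-one prefactors in Python double precision before any JAX array is involved.

```python
import math

import jax.numpy as jnp

# SI constants (same values as listed in the wavelength_ang() Notes).
H = 6.62607015e-34       # Planck constant, J s
M_E = 9.1093837015e-31   # electron rest mass, kg
Q_E = 1.602176634e-19    # elementary charge, C
C = 299792458.0          # speed of light, m/s

# Prefactors evaluated in Python double precision so float32 never sees h**2.
INV_REST_ENERGY_V = Q_E / (M_E * C**2)                      # 1/V, about 1.957e-6
SIGMA_PREFACTOR = 2.0 * math.pi * M_E * Q_E / H**2 * 1e-20  # 1/(V * Angstrom**2)


def interaction_constant_sketch(voltage_kv, wavelength_ang):
    """Hypothetical sigma in 1/(V * Angstrom); not rheedium's implementation."""
    # Step 1: kV -> V. The Angstrom -> m factor (1e-10) and the final
    # 1/(V m) -> 1/(V Angstrom) factor (1e-10) are folded into SIGMA_PREFACTOR.
    volts = jnp.asarray(voltage_kv) * 1e3
    # Step 2: Lorentz factor gamma = 1 + eV / (m_e c^2).
    gamma = 1.0 + INV_REST_ENERGY_V * volts
    # Step 3: sigma = 2 pi gamma m_e e lambda / h^2.
    return SIGMA_PREFACTOR * gamma * jnp.asarray(wavelength_ang)
```

At 100 kV, with the relativistic wavelength of about 0.0370 Å, this gives roughly \(9.24 \times 10^{-4}\) per V·Å.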

rheedium.tools.wavelength_ang(voltage_kv: int | float | Num[Array, ''] | Num[Array, '...']) → Float[Array, '...']

Calculate the relativistic electron wavelength in angstroms.

Extended Summary

Uses the full relativistic de Broglie wavelength formula:

lambda = h / sqrt(2 * m_e * e * V * (1 + e*V / (2 * m_e * c^2)))

This is more accurate than the non-relativistic approximation, and the difference grows with voltage: the relativistic wavelength is about 1% shorter at 20 keV, about 1.4% shorter at 30 keV, and about 4.6% shorter at 100 keV.

param voltage_kv:

Electron energy in kiloelectron volts. Can be either a scalar or an array.

type voltage_kv:

Union[int, float, Num[Array, ''], Num[Array, '...']]

returns:

wavelength – Electron wavelength in angstroms.

rtype:

Float[Array, '...']

Notes

Physical constants used:

  • h = 6.62607015e-34 J·s (Planck constant, exact)

  • m_e = 9.1093837015e-31 kg (electron mass)

  • e = 1.602176634e-19 C (elementary charge, exact)

  • c = 299792458 m/s (speed of light, exact)

The formula simplifies to:

lambda(Å) = 12.2643 / sqrt(V * (1 + 0.978476e-6 * V))

where V is in volts and the coefficient 0.978476e-6 = e / (2 * m_e * c^2).

  1. Convert voltage – Multiply kV by 1000 to obtain voltage in Volts.

  2. Relativistic correction – Compute the corrected voltage \(V_{corr} = V \left(1 + eV / (2 m_e c^2)\right)\).

  3. Wavelength calculation – Compute \(\lambda = h / \sqrt{2 m_e e V_{corr}}\) and return it in Ångstroms.

Examples

>>> import rheedium as rh
>>> import jax.numpy as jnp
>>> lam = rh.tools.wavelength_ang(jnp.asarray(20.0))  # 20 keV
>>> print(f"λ = {lam:.4f} Å")
λ = 0.0859 Å
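The Notes steps can be sketched as follows. `wavelength_ang_sketch` and its `relativistic` flag are hypothetical, added only so the relativistic and non-relativistic results can be compared directly. As with the interaction constant, the prefactors \(h/\sqrt{2 m_e e}\) (about 12.2643 Å·V^½) and \(e/(2 m_e c^2)\) (about 0.978476e-6 V⁻¹) are evaluated in Python double precision so no float32 array has to hold the raw SI constants.

```python
import math

import jax.numpy as jnp

H = 6.62607015e-34       # Planck constant, J s
M_E = 9.1093837015e-31   # electron rest mass, kg
Q_E = 1.602176634e-19    # elementary charge, C
C = 299792458.0          # speed of light, m/s

LAMBDA_PREFACTOR = H / math.sqrt(2.0 * M_E * Q_E) * 1e10  # Angstrom * V**0.5
REL_COEFF = Q_E / (2.0 * M_E * C**2)                       # 1/V


def wavelength_ang_sketch(voltage_kv, relativistic=True):
    """Hypothetical electron wavelength in Angstrom (not rheedium's code)."""
    # Step 1: kV -> V.
    volts = jnp.asarray(voltage_kv) * 1e3
    # Step 2: V_corr = V * (1 + eV / (2 m_e c^2)); skipped when relativistic=False.
    v_corr = volts * (1.0 + REL_COEFF * volts) if relativistic else volts
    # Step 3: lambda = h / sqrt(2 m_e e V_corr), already scaled to Angstrom.
    return LAMBDA_PREFACTOR / jnp.sqrt(v_corr)
```

At 20 kV this reproduces the 0.0859 Å of the example above, and the ratio of the two modes at 30 kV shows the roughly 1.4% relativistic shortening.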

rheedium.tools.bessel_k0(x: Float[Array, '...']) → Float[Array, '...']

Compute modified Bessel function of the second kind, order zero.

Return type:

Float[Array, '...']
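The package Notes name the Abramowitz & Stegun piecewise approximations as the basis for bessel_k0(). The sketch below uses the published A&S polynomials (9.8.1 for \(I_0\), 9.8.5 for \(0 < x \le 2\), 9.8.6 for \(x > 2\)); `bessel_k0_sketch` is hypothetical, and whether rheedium uses exactly these branches and the breakpoint at x = 2 is an assumption. Both branches are evaluated and selected with `jnp.where`, which keeps the function traceable under `jax.jit`.

```python
import jax.numpy as jnp


def _bessel_i0_small(x):
    # A&S 9.8.1 in t = x / 3.75, valid for |x| <= 3.75 (used here only for x <= 2).
    y = (x / 3.75) ** 2
    return 1.0 + y * (3.5156229 + y * (3.0899424 + y * (1.2067492
        + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))


def bessel_k0_sketch(x):
    """Hypothetical K_0(x) for x > 0 from the A&S polynomial approximations."""
    x = jnp.asarray(x)
    # A&S 9.8.5: K_0 = -ln(x/2) I_0(x) + polynomial in (x/2)^2, for 0 < x <= 2.
    y = x * x / 4.0
    small = -jnp.log(x / 2.0) * _bessel_i0_small(x) + (-0.57721566 + y * (0.42278420
        + y * (0.23069756 + y * (0.03488590 + y * (0.00262698
        + y * (0.00010750 + y * 0.00000740))))))
    # A&S 9.8.6: sqrt(x) e^x K_0 = polynomial in 2/x, for x >= 2.
    z = 2.0 / x
    large = jnp.exp(-x) / jnp.sqrt(x) * (1.25331414 + z * (-0.07832358
        + z * (0.02189568 + z * (-0.01062446 + z * (0.00587872
        + z * (-0.00251540 + z * 0.00053208))))))
    return jnp.where(x <= 2.0, small, large)
```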

rheedium.tools.bessel_k1(x: Float[Array, '...']) → Float[Array, '...']

Compute modified Bessel function of the second kind, order one.

Return type:

Float[Array, '...']
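The order-one counterpart follows the same pattern with A&S 9.8.3 (for \(I_1\)), 9.8.7 (\(0 < x \le 2\)), and 9.8.8 (\(x > 2\)). As before, `bessel_k1_sketch` is a hypothetical illustration of the approach named in the package Notes, not rheedium's source, and the x = 2 breakpoint is an assumption.

```python
import jax.numpy as jnp


def _bessel_i1_small(x):
    # A&S 9.8.3 in t = x / 3.75, valid for |x| <= 3.75 (used here only for x <= 2).
    y = (x / 3.75) ** 2
    return x * (0.5 + y * (0.87890594 + y * (0.51498869 + y * (0.15084934
        + y * (0.02658733 + y * (0.00301532 + y * 0.00032411))))))


def bessel_k1_sketch(x):
    """Hypothetical K_1(x) for x > 0 from the A&S polynomial approximations."""
    x = jnp.asarray(x)
    # A&S 9.8.7: x K_1 = x ln(x/2) I_1(x) + polynomial in (x/2)^2, for 0 < x <= 2.
    y = x * x / 4.0
    small = jnp.log(x / 2.0) * _bessel_i1_small(x) + (1.0 / x) * (1.0 + y * (0.15443144
        + y * (-0.67278579 + y * (-0.18156897 + y * (-0.01919402
        + y * (-0.00110404 + y * (-0.00004686)))))))
    # A&S 9.8.8: sqrt(x) e^x K_1 = polynomial in 2/x, for x >= 2.
    z = 2.0 / x
    large = jnp.exp(-x) / jnp.sqrt(x) * (1.25331414 + z * (0.23498619
        + z * (-0.03655620 + z * (0.01504268 + z * (-0.00780353
        + z * (0.00325614 + z * (-0.00068245)))))))
    return jnp.where(x <= 2.0, small, large)
```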

rheedium.tools.bessel_kv(v: float | Float[Array, ''], x: Float[Array, '...']) → Float[Array, '...']

Compute the modified Bessel function of the second kind K_v(x).

Return type:

Float[Array, '...']
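The series and asymptotic expansions behind bessel_kv() are not spelled out here, so the following sketch deliberately uses a different technique: the integral representation \(K_v(x) = \int_0^\infty e^{-x\cosh t}\cosh(vt)\,dt\) (valid for x > 0), evaluated with the trapezoidal rule. It is slower than an expansion but makes an independent reference for spot-checking. `bessel_kv_quadrature`, the cut-off `t_max = 10`, and the step count are assumptions chosen for moderate v and x.

```python
import math

import jax.numpy as jnp


def bessel_kv_quadrature(v, x, t_max=10.0, n_steps=400):
    """Hypothetical reference K_v(x), x > 0, via trapezoidal quadrature."""
    t = jnp.linspace(0.0, t_max, n_steps + 1)
    h = t_max / n_steps
    x = jnp.asarray(x, dtype=jnp.float32)[..., None]  # broadcast over the t grid
    expo = -x * jnp.cosh(t)
    # cosh(v t) e^{-x cosh t} written as two exponentials to avoid inf * 0.
    f = 0.5 * (jnp.exp(expo + v * t) + jnp.exp(expo - v * t))
    # Trapezoid rule: the endpoint samples carry half weight.
    return h * (jnp.sum(f, axis=-1) - 0.5 * (f[..., 0] + f[..., -1]))
```

Because the integrand is even in t and decays double-exponentially, the trapezoidal rule converges very quickly here. A convenient check is the closed form \(K_{1/2}(x) = \sqrt{\pi/(2x)}\,e^{-x}\).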

rheedium.tools.jax_safe(fn: Callable[..., Any]) → Callable[..., Any]

Wrap a function to convert positional args to JAX arrays.

Parameters:

fn (Callable[..., Any]) – Function whose positional arguments (Python or numpy scalars and arrays) should be converted via jnp.asarray before dispatch.

Returns:

wrapper – Wrapped function that calls jnp.asarray on each positional argument.

Return type:

Callable[..., Any]

Notes

  1. Iterate over all positional arguments passed to fn.

  2. Call jnp.asarray on each, converting numpy scalars and arrays to their JAX equivalents.

  3. Forward the converted arguments to fn and return the result unchanged.

This is required when using jax.test_util.check_grads, which perturbs inputs via numpy arithmetic. The perturbed values are numpy scalars (f64[](numpy)) that fail beartype’s Float[Array, ''] checks. Wrapping the function under test with jax_safe resolves this.
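The three Notes steps amount to a small decorator, sketched below with `functools.wraps`. `jax_safe_sketch` is hypothetical. Passing keyword arguments through unconverted is an assumption, because the Notes only mention positional arguments.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def jax_safe_sketch(fn):
    """Hypothetical equivalent of jax_safe: jnp.asarray on every positional arg."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Steps 1-2: convert numpy scalars, numpy arrays and Python numbers.
        converted = tuple(jnp.asarray(a) for a in args)
        # Step 3: forward and return unchanged (kwargs untouched, by assumption).
        return fn(*converted, **kwargs)

    return wrapper


@jax_safe_sketch
def identity_args(*args):
    # Returns what the wrapped function actually received.
    return args


received = identity_args(np.float64(1.5), np.ones(3), 2.0)
```

Every element of `received` is a `jax.Array`, so a beartype check such as `Float[Array, '']` on the wrapped function would see JAX arrays rather than numpy scalars.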