Source code for rheedium.ucell.helper

"""Helper functions for unit cell calculations and transformations.

Extended Summary
----------------
This module provides utility functions for crystallographic calculations,
including vector operations, lattice parameter computations, and crystal
structure filtering based on geometric criteria.

Routine Listings
----------------
:func:`angle_in_degrees`
    Calculate the angle in degrees between two vectors.
:func:`compute_lengths_angles`
    Compute unit cell lengths and angles from lattice vectors.
:func:`parse_cif_and_scrape`
    Parse CIF file and filter atoms within specified thickness.

Notes
-----
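The vector helpers (``angle_in_degrees`` and ``compute_lengths_angles``) are
pure JAX computations and support automatic differentiation.
``parse_cif_and_scrape`` reads a file from disk and selects atoms with a
boolean mask, so its output size depends on the data; it is meant to be
called eagerly rather than inside ``jax.jit``.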
"""

from pathlib import Path

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple, Union
from jaxtyping import Array, Bool, Float, Real, jaxtyped

import rheedium as rh
from rheedium.types import CrystalStructure, create_crystal_structure


@jaxtyped(typechecker=beartype)
def angle_in_degrees(
    v1: Float[Array, "n"], v2: Float[Array, "n"]
) -> Float[Array, ""]:
    """Calculate the angle in degrees between two vectors.

    Vectors of any dimensionality are accepted as long as both have the
    same number of elements. The shared ``"n"`` axis name in the
    annotations lets the ``jaxtyped`` decorator reject mismatched lengths,
    so no explicit shape check is performed in the body.

    Parameters
    ----------
    v1 : Float[Array, "n"]
        First vector
    v2 : Float[Array, "n"]
        Second vector

    Returns
    -------
    angle : Float[Array, ""]
        Angle between vectors in degrees, in the range [0, 180].

    Notes
    -----
    The cosine is clamped to ``[-1, 1]`` before ``arccos`` is applied.
    Floating-point rounding can push the normalised dot product of
    parallel or anti-parallel vectors marginally outside that interval,
    which would otherwise produce NaN instead of 0 or 180 degrees.

    A zero-length input still yields NaN, because the normalisation
    divides by the product of the two norms. The derivative of ``arccos``
    diverges at a cosine of exactly +/-1, so gradients are not finite for
    exactly (anti)parallel inputs.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import rheedium as rh
    >>> v1 = jnp.array([1.0, 0.0, 0.0])
    >>> v2 = jnp.array([0.0, 1.0, 0.0])
    >>> angle = rh.ucell.angle_in_degrees(v1, v2)
    >>> print(angle)
    90.0
    """
    cos_theta: Float[Array, ""] = jnp.dot(v1, v2) / (
        jnp.linalg.norm(v1) * jnp.linalg.norm(v2)
    )
    # Guard arccos against rounding overshoot such as 1.0000001.
    cos_theta = jnp.clip(cos_theta, -1.0, 1.0)
    angle: Float[Array, ""] = 180.0 * jnp.arccos(cos_theta) / jnp.pi
    return angle
@jaxtyped(typechecker=beartype)
def compute_lengths_angles(
    vectors: Float[Array, "3 3"],
) -> Tuple[Float[Array, "3"], Float[Array, "3"]]:
    """Derive cell edge lengths and inter-axis angles from lattice vectors.

    Parameters
    ----------
    vectors : Float[Array, "3 3"]
        Lattice vectors stored as the rows of a 3x3 matrix.

    Returns
    -------
    lengths : Float[Array, "3"]
        Euclidean norm of each row, in the same units as ``vectors``
        (angstroms when the input is in angstroms).
    angles : Float[Array, "3"]
        Angles in degrees, ordered as: between rows 1 and 2, between
        rows 0 and 2, between rows 0 and 1. With rows taken as a, b, c
        this is the usual alpha, beta, gamma ordering.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import rheedium as rh
    >>> # Cubic unit cell with a = 5.0 angstroms
    >>> vectors = jnp.array(
    ...     [
    ...         [5.0, 0.0, 0.0],
    ...         [0.0, 5.0, 0.0],
    ...         [0.0, 0.0, 5.0],
    ...     ]
    ... )
    >>> lengths, angles = rh.ucell.compute_lengths_angles(vectors)
    >>> # lengths are all 5.0 and angles are all 90 degrees
    """
    vec_a, vec_b, vec_c = vectors[0], vectors[1], vectors[2]

    norms = [jnp.linalg.norm(row) for row in (vec_a, vec_b, vec_c)]
    lengths: Float[Array, "3"] = jnp.array(norms)

    # Each angle is the one opposite the axis that is left out of the pair.
    axis_pairs = ((vec_b, vec_c), (vec_a, vec_c), (vec_a, vec_b))
    pair_angles = [angle_in_degrees(first, second) for first, second in axis_pairs]
    angles: Float[Array, "3"] = jnp.array(pair_angles)

    return (lengths, angles)
@jaxtyped(typechecker=beartype)
def parse_cif_and_scrape(
    cif_path: Union[str, Path],
    zone_axis: Real[Array, " 3"],
    thickness_xyz: Real[Array, " 3"],
) -> CrystalStructure:
    """Load a CIF file and keep only the atoms inside a slab.

    The structure returned by ``rh.inout.parse_cif`` is projected onto the
    zone axis, and only atoms whose projection lies within a slab centred
    on the middle of the projected extent are retained.

    Parameters
    ----------
    cif_path : Union[str, Path]
        Path to the CIF file.
    zone_axis : Real[Array, " 3"]
        Vector giving the zone axis direction (surface normal) in
        Cartesian coordinates. It is normalised internally.
    thickness_xyz : Real[Array, " 3"]
        Thickness along x, y, z in angstroms. Only ``thickness_xyz[2]`` is
        used; it is taken as the full slab thickness along ``zone_axis``.

    Returns
    -------
    filtered_crystal : CrystalStructure
        Structure holding only the retained atoms, with the cell lengths
        and angles of the parsed crystal.

    Notes
    -----
    The steps are:

    1. **Parse CIF** -- load the crystal via ``rh.inout.parse_cif``.
    2. **Extract coordinates** -- split ``cart_positions`` into the first
       three columns (x, y, z) and the fourth column (atomic numbers).
    3. **Normalise zone axis** -- divide by its norm plus ``1e-12`` so a
       zero vector does not cause a division by zero.
    4. **Project** -- dot every position with the unit zone axis.
    5. **Thickness filter** -- take the midpoint of the minimum and
       maximum projection and keep atoms within ``thickness_xyz[2] / 2``
       of it (inclusive).
    6. **Filter atoms** -- apply the boolean mask to the positions and
       the atomic numbers.
    7. **Fractional coordinates** -- build cell vectors with
       ``rh.ucell.build_cell_vectors``, invert the matrix, multiply the
       retained Cartesian positions by it and wrap the result modulo 1.
    8. **Build output** -- call ``create_crystal_structure`` with the
       filtered positions and the original cell parameters.

    Atoms are selected with a boolean mask, so the number of atoms
    returned depends on the data.
    """
    crystal: CrystalStructure = rh.inout.parse_cif(cif_path)

    # The fourth column is sliced as 3:4 so it stays 2-D and can be
    # concatenated back onto the coordinate columns later.
    positions = crystal.cart_positions
    xyz: Float[Array, "n 3"] = positions[:, :3]
    species: Float[Array, "n 1"] = positions[:, 3:4]

    # The small epsilon keeps the division finite for a zero zone axis.
    axis_length: Float[Array, ""] = jnp.linalg.norm(zone_axis)
    unit_axis: Float[Array, "3"] = zone_axis / (axis_length + 1e-12)

    # Signed position of every atom along the zone axis.
    depth: Float[Array, "n"] = xyz @ unit_axis
    lowest: Float[Array, ""] = jnp.min(depth)
    highest: Float[Array, ""] = jnp.max(depth)
    midpoint: Float[Array, ""] = (highest + lowest) / 2.0
    half_width: Float[Array, ""] = thickness_xyz[2] / 2.0
    keep: Bool[Array, "n"] = jnp.abs(depth - midpoint) <= half_width

    kept_xyz: Float[Array, "m 3"] = xyz[keep]
    kept_species: Float[Array, "m 1"] = species[keep]

    # Three lengths followed by three angles, passed positionally.
    cell_params = [crystal.cell_lengths[i] for i in range(3)] + [
        crystal.cell_angles[i] for i in range(3)
    ]
    lattice: Float[Array, "3 3"] = rh.ucell.build_cell_vectors(*cell_params)
    lattice_inv: Float[Array, "3 3"] = jnp.linalg.inv(lattice)
    kept_frac: Float[Array, "m 3"] = (kept_xyz @ lattice_inv) % 1.0

    def _attach_species(coords: Float[Array, "m 3"]) -> Float[Array, "m 4"]:
        # Re-append the atomic-number column to a coordinate block.
        return jnp.concatenate([coords, kept_species], axis=1)

    filtered_crystal: CrystalStructure = create_crystal_structure(
        frac_positions=_attach_species(kept_frac),
        cart_positions=_attach_species(kept_xyz),
        cell_lengths=crystal.cell_lengths,
        cell_angles=crystal.cell_angles,
    )
    return filtered_crystal
# Public names exported by this module.
__all__: list[str] = [
    "angle_in_degrees",
    "compute_lengths_angles",
    "parse_cif_and_scrape",
]