# Source code for rheedium.ucell.unitcell

"""Functions for unit cell calculations and transformations.

Extended Summary
----------------
This module provides functions for crystallographic unit cell operations
including reciprocal space calculations, lattice transformations, and atom
filtering for specific zones and thicknesses.

Routine Listings
----------------
:func:`bulk_to_slice`
    Transform bulk crystal structure into surface-oriented
    slab for RHEED simulation.
:func:`reciprocal_unitcell`
    Calculate reciprocal unit cell parameters from direct
    cell parameters.
:func:`get_unit_cell_matrix`
    Build transformation matrix between direct and
    reciprocal space.
:func:`build_cell_vectors`
    Construct unit cell vectors from lengths and angles.
:func:`compute_lengths_angles`
    Compute unit cell lengths and angles from lattice
    vectors.
:func:`generate_reciprocal_points`
    Generate reciprocal lattice points for given hkl
    ranges.
:func:`atom_scraper`
    Filter atoms within specified thickness along zone
    axis.
:func:`reciprocal_lattice_vectors`
    Generate reciprocal lattice basis vectors b₁, b₂, b₃.
:func:`miller_to_reciprocal`
    Convert Miller indices to reciprocal space vectors.

Notes
-----
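Functions are written with ``jax.numpy`` and support automatic
differentiation. ``atom_scraper`` and ``bulk_to_slice`` select atoms with
boolean masks (data-dependent output sizes), so they must be called
eagerly rather than inside ``jax.jit``.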
"""

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple
from jax import lax
from jaxtyping import Array, Bool, Float, Int, Num, jaxtyped

from rheedium.types import (
    CrystalStructure,
    SlicedCrystal,
    create_crystal_structure,
    create_sliced_crystal,
    scalar_bool,
    scalar_float,
    scalar_int,
)


@jaxtyped(typechecker=beartype)
def reciprocal_unitcell(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
    in_degrees: scalar_bool = True,
    out_degrees: scalar_bool = True,
) -> Tuple[Float[Array, "3"], Float[Array, "3"]]:
    """Calculate reciprocal unit cell parameters from direct cell parameters.

    Uses the physics convention that includes the factor 2*pi, i.e.
    ``a* = 2*pi * b * c * sin(alpha) / V``.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha : scalar_float
        Direct cell angle between the b and c axes.
    beta : scalar_float
        Direct cell angle between the a and c axes.
    gamma : scalar_float
        Direct cell angle between the a and b axes.
    in_degrees : bool
        Interpret the input angles as degrees (default) or radians.
    out_degrees : bool
        Return the reciprocal angles in degrees (default) or radians.

    Returns
    -------
    reciprocal_lengths : Float[Array, "3"]
        ``[a*, b*, c*]`` in 1/angstroms (2*pi included).
    reciprocal_angles : Float[Array, "3"]
        ``[alpha*, beta*, gamma*]`` in degrees or radians.

    Notes
    -----
    For every axis ``i`` with the remaining axes ``j < k``:

    - ``V = a b c sqrt(1 - cos²α - cos²β - cos²γ + 2 cosα cosβ cosγ)``
    - ``len*_i = 2*pi * len_j * len_k * sin(angle_i) / V``
    - ``cos(angle*_i) = (cos_j cos_k - cos_i) / (sin_j sin_k)``, clipped
      to [-1, 1] before ``arccos``.

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    build_cell_vectors : Build direct lattice vectors.
    """
    lengths = (a, b, c)
    radians = tuple(
        jnp.where(in_degrees, jnp.deg2rad(angle), angle)
        for angle in (alpha, beta, gamma)
    )
    cosines = tuple(jnp.cos(angle) for angle in radians)
    sines = tuple(jnp.sin(angle) for angle in radians)
    cos_a, cos_b, cos_g = cosines

    # (V / abc)^2 from the metric tensor determinant.
    metric_det = (
        1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g
    )
    volume = a * b * c * jnp.sqrt(metric_det)

    two_pi = 2.0 * jnp.pi
    # For axis i, (j, k) are the two remaining axes in ascending order.
    partners = ((1, 2), (0, 2), (0, 1))
    star_lengths = []
    star_angles = []
    for i, (j, k) in enumerate(partners):
        star_lengths.append(
            two_pi * lengths[j] * lengths[k] * sines[i] / volume
        )
        cos_star = (cosines[j] * cosines[k] - cosines[i]) / (
            sines[j] * sines[k]
        )
        star_rad = jnp.arccos(jnp.clip(cos_star, -1.0, 1.0))
        star_angles.append(
            jnp.where(out_degrees, jnp.rad2deg(star_rad), star_rad)
        )
    return jnp.array(star_lengths), jnp.array(star_angles)
@jaxtyped(typechecker=beartype)
def get_unit_cell_matrix(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
) -> Float[Array, "3 3"]:
    r"""Build the fractional-to-Cartesian matrix of a unit cell.

    The columns of the returned matrix are the direct lattice vectors a, b
    and c in Cartesian coordinates, with a along x and b in the x-y plane.
    Multiplying a column vector of fractional coordinates by this matrix
    gives Cartesian coordinates in angstroms.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha, beta, gamma : scalar_float
        Direct cell angles in degrees.

    Returns
    -------
    Float[Array, "3 3"]
        Upper-triangular matrix ``M`` with ``r_cart = M @ r_frac``.
        Mathematically ``M`` equals the transpose of
        ``build_cell_vectors(a, b, c, alpha, beta, gamma)``.

    Notes
    -----
    1. **Angle Conversion** -- Convert cell angles from degrees to radians
       and take their cosines and the sine of gamma.
    2. **Volume Factor** -- Compute
       ``sqrt(1 - cos²α - cos²β - cos²γ + 2 cosα cosβ cosγ)`` (cell volume
       divided by ``a*b*c``). The radicand is clamped at zero so rounding
       on (near-)degenerate angle sets yields 0 instead of NaN, matching
       the clamp used by ``build_cell_vectors``.
    3. **Matrix Assembly** -- Fill the upper triangle so that column 0 is
       a, column 1 is b and column 2 is c.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheedium.ucell.unitcell import get_unit_cell_matrix
    >>> matrix = get_unit_cell_matrix(
    ...     a=3.0, b=3.0, c=3.0, alpha=90.0, beta=90.0, gamma=90.0
    ... )
    >>> frac = jnp.array([0.5, 0.5, 0.5])
    >>> cart = matrix @ frac  # approximately [1.5, 1.5, 1.5] angstroms

    See Also
    --------
    build_cell_vectors : Same lattice vectors, stored as rows.
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    """
    alpha_rad: Float[Array, ""] = jnp.radians(alpha)
    beta_rad: Float[Array, ""] = jnp.radians(beta)
    gamma_rad: Float[Array, ""] = jnp.radians(gamma)
    cos_angles: Float[Array, "3"] = jnp.array(
        [jnp.cos(alpha_rad), jnp.cos(beta_rad), jnp.cos(gamma_rad)]
    )
    cos_alpha, cos_beta, cos_gamma = cos_angles
    sin_gamma: Float[Array, ""] = jnp.sin(gamma_rad)

    # Clamp: rounding can push the radicand slightly below zero.
    volume_factor: Float[Array, ""] = jnp.sqrt(
        jnp.maximum(
            1 - jnp.sum(jnp.square(cos_angles)) + (2 * jnp.prod(cos_angles)),
            0.0,
        )
    )

    matrix: Float[Array, "3 3"] = jnp.zeros(shape=(3, 3), dtype=jnp.float64)
    # Column 0: a along x.
    matrix = matrix.at[0, 0].set(a)
    # Column 1: b in the x-y plane.
    matrix = matrix.at[0, 1].set(b * cos_gamma)
    matrix = matrix.at[1, 1].set(b * sin_gamma)
    # Column 2: c, with its z component fixed by the cell volume.
    matrix = matrix.at[0, 2].set(c * cos_beta)
    matrix = matrix.at[1, 2].set(
        c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    )
    matrix = matrix.at[2, 2].set(c * volume_factor / sin_gamma)
    return matrix
@jaxtyped(typechecker=beartype)
def build_cell_vectors(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
) -> Float[Array, "3 3"]:
    r"""Construct direct lattice vectors from cell lengths and angles.

    The standard crystallographic orientation is used: a lies along x, b lies
    in the x-y plane, and c completes the cell.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha, beta, gamma : scalar_float
        Direct cell angles in degrees.

    Returns
    -------
    Float[Array, "3 3"]
        Lattice vectors stored as rows ``[a_vec, b_vec, c_vec]``.

    Notes
    -----
    - ``a_vec = [a, 0, 0]``
    - ``b_vec = [b cos γ, b sin γ, 0]``
    - ``c_vec = [c cos β, c (cos α - cos β cos γ) / sin γ, c_z]`` with
      ``c_z = sqrt(max(c² - c_x² - c_y², 0))``. The clamp keeps the square
      root real when rounding makes the difference slightly negative.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheedium.ucell.unitcell import build_cell_vectors
    >>> vectors = build_cell_vectors(
    ...     a=3.0, b=3.0, c=3.0, alpha=90.0, beta=90.0, gamma=90.0
    ... )
    >>> volume = jnp.linalg.det(vectors)  # approximately 27 Å^3

    See Also
    --------
    compute_lengths_angles : Inverse operation from vectors to parameters.
    reciprocal_lattice_vectors : Build reciprocal lattice vectors.
    """
    cos_alpha, cos_beta, cos_gamma = (
        jnp.cos(jnp.radians(angle)) for angle in (alpha, beta, gamma)
    )
    sin_gamma = jnp.sin(jnp.radians(gamma))

    first_row = jnp.array([a, 0.0, 0.0])
    second_row = jnp.array([b * cos_gamma, b * sin_gamma, 0.0])

    third_x = c * cos_beta
    third_y = c * ((cos_alpha - cos_beta * cos_gamma) / sin_gamma)
    third_z = jnp.sqrt(jnp.clip(c**2 - third_x**2 - third_y**2, min=0.0))
    third_row = jnp.array([third_x, third_y, third_z])

    return jnp.stack([first_row, second_row, third_row], axis=0)
@jaxtyped(typechecker=beartype)
def compute_lengths_angles(
    vectors: Float[Array, "3 3"],
) -> Tuple[Float[Array, "3"], Float[Array, "3"]]:
    """Recover cell lengths and angles from lattice vectors.

    Parameters
    ----------
    vectors : Float[Array, "3 3"]
        Lattice vectors stored as rows ``[a_vec, b_vec, c_vec]``.

    Returns
    -------
    lengths : Float[Array, "3"]
        Row norms ``[a, b, c]`` in angstroms.
    angles : Float[Array, "3"]
        ``[alpha, beta, gamma]`` in degrees, where alpha is the angle
        between b and c, beta between a and c, and gamma between a and b.

    Notes
    -----
    Each cosine is the dot product of two rows divided by the product of
    their lengths. It is clipped to [-1, 1] before ``arccos`` so rounding
    cannot produce NaN.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheedium.ucell.unitcell import compute_lengths_angles
    >>> vectors = jnp.array(
    ...     [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
    ... )
    >>> lengths, angles = compute_lengths_angles(vectors)

    See Also
    --------
    build_cell_vectors : Inverse operation from parameters to vectors.
    """
    lengths = jnp.linalg.norm(vectors, axis=1)
    # (i, j) rows whose enclosed angle is alpha, beta, gamma respectively.
    pair_indices = ((1, 2), (0, 2), (0, 1))
    degrees = []
    for i, j in pair_indices:
        cosine = jnp.dot(vectors[i], vectors[j]) / (lengths[i] * lengths[j])
        degrees.append(jnp.rad2deg(jnp.arccos(jnp.clip(cosine, -1.0, 1.0))))
    return lengths, jnp.array(degrees)
@jaxtyped(typechecker=beartype)
def generate_reciprocal_points(
    crystal: CrystalStructure,
    hmax: scalar_int,
    kmax: scalar_int,
    lmax: scalar_int,
    in_degrees: scalar_bool = True,
) -> Float[Array, "M 3"]:
    r"""Generate reciprocal-lattice vectors for a box of Miller indices.

    Parameters
    ----------
    crystal : CrystalStructure
        Crystal whose ``cell_lengths`` and ``cell_angles`` define the
        lattice.
    hmax, kmax, lmax : scalar_int
        Every index triple with ``-hmax <= h <= hmax`` (likewise for k and
        l) is generated.
    in_degrees : bool, optional
        Whether ``crystal.cell_angles`` are in degrees. It is forwarded to
        ``reciprocal_lattice_vectors``. Default: True.

    Returns
    -------
    Float[Array, "M 3"]
        Reciprocal lattice vectors G in 1/angstroms, with
        ``M = (2 hmax + 1)(2 kmax + 1)(2 lmax + 1)``. Rows are in
        C order of (h, k, l), with l varying fastest.

    Examples
    --------
    >>> import rheedium as rh
    >>> from rheedium.ucell.unitcell import generate_reciprocal_points
    >>> crystal = rh.inout.parse_cif("path/to/crystal.cif")
    >>> g_vectors = generate_reciprocal_points(
    ...     crystal=crystal, hmax=2, kmax=2, lmax=1
    ... )

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    miller_to_reciprocal : Convert Miller indices to G vectors.
    """
    a, b, c = crystal.cell_lengths
    alpha, beta, gamma = crystal.cell_angles
    basis = reciprocal_lattice_vectors(
        a, b, c, alpha, beta, gamma, in_degrees=in_degrees
    )

    index_ranges = [
        jnp.arange(-limit, limit + 1) for limit in (hmax, kmax, lmax)
    ]
    grids = jnp.meshgrid(*index_ranges, indexing="ij")
    miller = jnp.stack([grid.ravel() for grid in grids], axis=-1)
    return miller_to_reciprocal(miller, basis)
@jaxtyped(typechecker=beartype)
def atom_scraper(
    crystal: CrystalStructure,
    zone_axis: Float[Array, "3"],
    thickness: Float[Array, "3"],
) -> CrystalStructure:
    """Filter atoms within a specified thickness along a zone axis.

    Each atom is measured by how far below the topmost atom it lies along
    the zone axis. Atoms within the requested thickness are kept, and the
    cell vectors are rescaled along the zone axis to the new slab height.

    Parameters
    ----------
    crystal : CrystalStructure
        Crystal structure to filter.
    zone_axis : Float[Array, "3"]
        Zone axis direction in Cartesian coordinates. It is normalised
        internally, so its length does not matter.
    thickness : Float[Array, "3"]
        Thickness vector in angstroms. Only its projection onto the zone
        axis, ``|thickness . zone_axis_hat|``, is used. A projection of
        (nearly) zero selects top-layer mode.

    Returns
    -------
    filtered_crystal : CrystalStructure
        Crystal containing only the selected atoms. Cell lengths and angles
        are recomputed from the rescaled cell vectors.

    Notes
    -----
    1. **Build Cell Vectors** -- Construct direct lattice vectors from the
       crystal's cell parameters.
    2. **Normalize Zone Axis** -- Compute a unit vector along the zone
       axis.
    3. **Project and Measure** -- Project each Cartesian position onto the
       zone axis and compute its distance from the topmost projection.
    4. **Adaptive Threshold** -- ``adaptive_eps = max(1e-3, 2 g)``, where
       ``g`` is the smallest distance-from-top above 1e-8. When every atom
       lies in the top plane, ``adaptive_eps = 1e-3``. The minimum is a
       masked reduction, so the function is well defined for a single atom
       or a single coplanar layer.
    5. **Apply Thickness Mask** -- Keep atoms with distance-from-top
       ``<= thickness_along_axis``. In top-layer mode the bound is
       ``<= adaptive_eps``.
    6. **Gather Filtered Atoms** -- Select fractional and Cartesian rows
       that pass the mask.
    7. **Scale Cell Vectors** -- Scale the component of each cell vector
       along the zone axis by ``new_height / original_height``, where
       ``original_height`` is the spread of atomic projections. Components
       perpendicular to the zone axis are left unchanged.
    8. **Build Output Structure** -- Create a new ``CrystalStructure``.

    Atoms are gathered with a boolean mask, so the output size depends on
    the data. Call this function outside ``jax.jit``.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import rheedium as rh
    >>> from rheedium.ucell.unitcell import atom_scraper
    >>> crystal = rh.inout.parse_cif("path/to/crystal.cif")
    >>> filtered = atom_scraper(
    ...     crystal=crystal,
    ...     zone_axis=jnp.array([1.0, 1.0, 1.0]),
    ...     thickness=jnp.array([12.0, 12.0, 12.0]),
    ... )

    See Also
    --------
    build_cell_vectors : Construct unit cell vectors.
    compute_lengths_angles : Compute cell parameters from vectors.
    create_crystal_structure : Create filtered crystal structure.
    """
    orig_cell_vectors: Float[Array, "3 3"] = build_cell_vectors(
        crystal.cell_lengths[0],
        crystal.cell_lengths[1],
        crystal.cell_lengths[2],
        crystal.cell_angles[0],
        crystal.cell_angles[1],
        crystal.cell_angles[2],
    )
    zone_axis_norm: Float[Array, ""] = jnp.linalg.norm(zone_axis)
    zone_axis_hat: Float[Array, "3"] = zone_axis / (zone_axis_norm + 1e-32)

    cart_xyz: Float[Array, "n 3"] = crystal.cart_positions[:, :3]
    dot_vals: Float[Array, "n"] = jnp.einsum(
        "ij,j->i", cart_xyz, zone_axis_hat
    )
    d_max: Float[Array, ""] = jnp.max(dot_vals)
    dist_from_top: Float[Array, "n"] = d_max - dot_vals

    distance_cutoff: float = 1e-8
    has_gap: Bool[Array, "n"] = dist_from_top > distance_cutoff
    # Masked minimum: boolean indexing followed by jnp.min raises when no
    # atom is below the top plane (empty reduction) and is not shape-static.
    smallest_gap: Float[Array, ""] = jnp.min(
        jnp.where(has_gap, dist_from_top, jnp.inf)
    )
    # NOTE(review): 2x the smallest non-zero gap also admits atoms sitting
    # exactly at that gap (the second layer) in top-layer mode — confirm
    # this is intended.
    adaptive_eps: Float[Array, ""] = jnp.where(
        jnp.any(has_gap),
        jnp.maximum(1e-3, 2 * smallest_gap),
        1e-3,
    )

    # Only the component of ``thickness`` along the zone axis matters.
    thickness_along_axis: Float[Array, ""] = jnp.abs(
        jnp.dot(thickness, zone_axis_hat)
    )
    is_top_layer_mode: Bool[Array, ""] = jnp.isclose(
        thickness_along_axis, jnp.asarray(0.0), atol=1e-8
    )
    mask: Bool[Array, "n"] = jnp.where(
        is_top_layer_mode,
        dist_from_top <= adaptive_eps,
        dist_from_top <= thickness_along_axis,
    )

    filtered_frac: Float[Array, "m 4"] = crystal.frac_positions[mask]
    filtered_cart: Float[Array, "m 4"] = crystal.cart_positions[mask]

    original_height: Float[Array, ""] = jnp.max(dot_vals) - jnp.min(dot_vals)
    new_height: Float[Array, ""] = jnp.where(
        is_top_layer_mode,
        adaptive_eps,
        jnp.minimum(thickness_along_axis, original_height),
    )
    height_is_tiny: Bool[Array, ""] = original_height < 1e-32
    # Double-where keeps the gradient finite when original_height is ~0.
    scale_factor: Float[Array, ""] = jnp.where(
        height_is_tiny,
        1.0,
        new_height / jnp.where(height_is_tiny, 1.0, original_height),
    )

    def _rescale_along_axis(vec: Float[Array, "3"]) -> Float[Array, "3"]:
        """Scale ``vec``'s zone-axis component, keep the perpendicular part."""
        parallel_mag = jnp.dot(vec, zone_axis_hat)
        parallel = parallel_mag * zone_axis_hat
        perpendicular = vec - parallel
        scaled = scale_factor * parallel + perpendicular
        # Vectors (almost) perpendicular to the zone axis are left untouched.
        return jnp.where(jnp.abs(parallel_mag) > distance_cutoff, scaled, vec)

    scaled_vectors: Float[Array, "3 3"] = jnp.stack(
        [_rescale_along_axis(vec) for vec in orig_cell_vectors]
    )
    new_lengths, new_angles = compute_lengths_angles(scaled_vectors)
    filtered_crystal: CrystalStructure = create_crystal_structure(
        frac_positions=filtered_frac,
        cart_positions=filtered_cart,
        cell_lengths=new_lengths,
        cell_angles=new_angles,
    )
    return filtered_crystal
@jaxtyped(typechecker=beartype)
def reciprocal_lattice_vectors(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
    in_degrees: scalar_bool = True,
) -> Float[Array, "3 3"]:
    """Generate reciprocal lattice basis vectors b₁, b₂, b₃.

    With direct vectors a₁, a₂, a₃ from ``build_cell_vectors`` and
    ``V = a₁ · (a₂ × a₃)``::

        b₁ = 2π (a₂ × a₃) / V
        b₂ = 2π (a₃ × a₁) / V
        b₃ = 2π (a₁ × a₂) / V

    so that ``aᵢ · bⱼ = 2π δᵢⱼ``.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha : scalar_float
        Direct cell angle between the b and c axes.
    beta : scalar_float
        Direct cell angle between the a and c axes.
    gamma : scalar_float
        Direct cell angle between the a and b axes.
    in_degrees : bool
        Interpret the angles as degrees (default) or radians.

    Returns
    -------
    reciprocal_vectors : Float[Array, "3 3"]
        Rows ``[b₁, b₂, b₃]`` in 1/angstroms.

    Notes
    -----
    The angles are normalised to degrees before calling
    ``build_cell_vectors``, which expects degrees.

    See Also
    --------
    build_cell_vectors : Build direct lattice vectors.
    reciprocal_unitcell : Compute reciprocal cell parameters.
    miller_to_reciprocal : Convert Miller indices to G vectors.
    """
    angles_deg = [
        jnp.rad2deg(jnp.where(in_degrees, jnp.deg2rad(angle), angle))
        for angle in (alpha, beta, gamma)
    ]
    a1, a2, a3 = build_cell_vectors(a, b, c, *angles_deg)

    normals = (jnp.cross(a2, a3), jnp.cross(a3, a1), jnp.cross(a1, a2))
    cell_volume = jnp.dot(a1, normals[0])
    factor = (2.0 * jnp.pi) / cell_volume
    return jnp.stack([factor * normal for normal in normals], axis=0)
@jaxtyped(typechecker=beartype)
def miller_to_reciprocal(
    hkl: Int[Array, "... 3"],
    reciprocal_vectors: Float[Array, "3 3"],
) -> Float[Array, "... 3"]:
    """Convert Miller indices to reciprocal space vectors.

    Computes ``G = h b₁ + k b₂ + l b₃`` for every index triple, broadcasting
    over any leading batch dimensions of ``hkl``.

    Parameters
    ----------
    hkl : Int[Array, "... 3"]
        Miller indices; the last axis holds ``[h, k, l]``.
    reciprocal_vectors : Float[Array, "3 3"]
        Reciprocal basis vectors as rows ``[b₁, b₂, b₃]`` in 1/angstroms,
        e.g. from ``reciprocal_lattice_vectors``.

    Returns
    -------
    g_vectors : Float[Array, "... 3"]
        Reciprocal space vectors in 1/angstroms, with the same batch shape
        as ``hkl``.

    Notes
    -----
    Indices are cast to float64 before the linear combination. JAX
    downgrades this to float32 when 64-bit mode is disabled.

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    generate_reciprocal_points : Generate G vectors from crystal structure.
    """
    indices = jnp.asarray(hkl, dtype=jnp.float64)
    b1, b2, b3 = reciprocal_vectors
    # Slicing with i:i+1 keeps a trailing axis of length 1 for broadcasting.
    return (
        indices[..., 0:1] * b1
        + indices[..., 1:2] * b2
        + indices[..., 2:3] * b3
    )
def _rotation_to_z(direction: Float[Array, "3"]) -> Float[Array, "3 3"]:
    """Return a proper rotation matrix that maps ``direction`` onto +z."""
    unit = direction / jnp.linalg.norm(direction)
    z_axis = jnp.array([0.0, 0.0, 1.0])
    rot_axis = jnp.cross(unit, z_axis)
    rot_axis_norm = jnp.linalg.norm(rot_axis)
    cos_angle = jnp.dot(unit, z_axis)
    angle = jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))

    def _axis_aligned() -> Float[Array, "3 3"]:
        # The Rodrigues axis is undefined when ``direction`` is
        # (anti-)parallel to z: use the identity for +z, and a 180-degree
        # turn about x (which sends -z to +z) for -z.
        half_turn_x = jnp.diag(jnp.array([1.0, -1.0, -1.0]))
        return jnp.where(cos_angle > 0.0, jnp.eye(3), half_turn_x).astype(
            unit.dtype
        )

    def _rodrigues() -> Float[Array, "3 3"]:
        k = rot_axis / (rot_axis_norm + 1e-10)
        skew = jnp.array(
            [
                [0.0, -k[2], k[1]],
                [k[2], 0.0, -k[0]],
                [-k[1], k[0], 0.0],
            ]
        )
        rotation = (
            jnp.eye(3)
            + jnp.sin(angle) * skew
            + (1 - jnp.cos(angle)) * (skew @ skew)
        )
        return rotation.astype(unit.dtype)

    return lax.cond(rot_axis_norm < 1e-6, _axis_aligned, _rodrigues)


def _lattice_index_bounds(
    cell_vecs: Float[Array, "3 3"],
    positions: Float[Array, "N 3"],
    box_lo: Float[Array, "3"],
    box_hi: Float[Array, "3"],
) -> tuple[list[int], list[int]]:
    """Inclusive integer ranges of lattice translations that can reach the box.

    An atom at ``p + n @ cell_vecs`` lies in the box only if ``frac(p) + n``
    lies in the fractional image of the box. That image is bounded by the
    fractional coordinates of the 8 box corners. One cell of padding on
    each side absorbs rounding.
    """
    to_frac = jnp.linalg.inv(cell_vecs)
    corner_bits = jnp.array(
        [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
        dtype=box_lo.dtype,
    )
    corners = box_lo + corner_bits * (box_hi - box_lo)
    corner_frac = corners @ to_frac
    atom_frac = positions @ to_frac
    lower = (
        jnp.floor(jnp.min(corner_frac, axis=0) - jnp.max(atom_frac, axis=0))
        - 1
    )
    upper = (
        jnp.ceil(jnp.max(corner_frac, axis=0) - jnp.min(atom_frac, axis=0))
        + 1
    )
    return [int(v) for v in lower], [int(v) for v in upper]


@jaxtyped(typechecker=beartype)
def bulk_to_slice(
    bulk_crystal: CrystalStructure,
    orientation: Int[Array, "3"],
    depth: scalar_float,
    x_extent: scalar_float = 150.0,
    y_extent: scalar_float = 150.0,
) -> SlicedCrystal:
    """Transform a bulk crystal structure into a surface-oriented slab.

    The crystal is rotated so that the normal of the (hkl) planes points
    along +z. It is then tiled with lattice translations in all three
    directions and cropped to an axis-aligned box of
    ``x_extent x y_extent x depth`` angstroms. Output coordinates lie in
    ``[0, x_extent]``, ``[0, y_extent]`` and ``[0, depth]``. z = 0 is the
    bottom of the slab and z = depth its top (surface) face.

    Parameters
    ----------
    bulk_crystal : CrystalStructure
        Bulk crystal structure from a CIF file or other source.
    orientation : Int[Array, "3"]
        Miller indices [h, k, l] of the desired surface, e.g. [1, 1, 1]
        for (111) or [0, 0, 1] for (001). At least one must be non-zero.
    depth : scalar_float
        Slab thickness along the surface normal, in angstroms
        (typically 10-30). Values larger than one unit cell are honoured.
    x_extent : scalar_float, optional
        Lateral extent in x, in angstroms. Default: 150.0.
    y_extent : scalar_float, optional
        Lateral extent in y, in angstroms. Default: 150.0.

    Returns
    -------
    sliced_crystal : SlicedCrystal
        Slab whose ``cart_positions`` rows are ``[x, y, z, atomic_number]``.
        Its cell lengths are ``[x_extent, y_extent, depth]`` and its cell
        angles ``[90, 90, 90]``.

    Raises
    ------
    ValueError
        If all three Miller indices are zero.

    Notes
    -----
    1. **Surface normal** -- Convert (hkl) to ``G = h b₁ + k b₂ + l b₃``
       with ``miller_to_reciprocal``.
    2. **Rotation** -- Rotate ``G / |G|`` onto +z with Rodrigues' formula.
       If G is already along +z the identity is used. If it is along -z, a
       180° rotation about x is used so the normal still ends up along +z.
    3. **Tiling** -- Derive lattice index ranges from the crop box corners
       in fractional coordinates, padded by one cell. This fills the box
       with bulk crystal for any orientation, including ones where no cell
       vector lies in the surface plane.
    4. **Cropping** -- In the rotated frame, keep atoms inside
       ``[0, x_extent] x [0, y_extent] x [-depth, 0]``. The tolerance is
       1e-6 Å, and survivors are clamped onto the box faces. The top face
       is the plane through the rotated lattice origin.
    5. **Output** -- Shift so the lower box corner is the origin.

    The lateral box is not snapped to the surface lattice. ``x_extent`` and
    ``y_extent`` are used as given, so the box is in general not an exact
    lateral period of the surface. Atoms on opposite lateral faces are both
    kept. Selection uses boolean masks and Python ``int`` conversions, so
    call this function outside ``jax.jit``.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import rheedium as rh
    >>> bulk = rh.inout.parse_cif("SrTiO3.cif")
    >>> slab = rh.ucell.bulk_to_slice(
    ...     bulk_crystal=bulk,
    ...     orientation=jnp.array([1, 1, 1]),
    ...     depth=20.0,
    ...     x_extent=150.0,
    ...     y_extent=150.0,
    ... )
    """
    orientation = jnp.asarray(orientation, dtype=jnp.int32)
    if not bool(jnp.any(orientation != 0)):
        raise ValueError(
            "orientation must contain at least one non-zero Miller index"
        )
    depth = jnp.asarray(depth, dtype=jnp.float64)
    x_extent = jnp.asarray(x_extent, dtype=jnp.float64)
    y_extent = jnp.asarray(y_extent, dtype=jnp.float64)

    cell_vecs: Float[Array, "3 3"] = build_cell_vectors(
        *bulk_crystal.cell_lengths, *bulk_crystal.cell_angles
    )
    recip_vecs: Float[Array, "3 3"] = reciprocal_lattice_vectors(
        *bulk_crystal.cell_lengths,
        *bulk_crystal.cell_angles,
        in_degrees=True,
    )
    surface_normal: Float[Array, "3"] = miller_to_reciprocal(
        orientation, recip_vecs
    )
    rotation_matrix: Float[Array, "3 3"] = _rotation_to_z(surface_normal)

    rotated_positions: Float[Array, "N 3"] = (
        bulk_crystal.cart_positions[:, :3] @ rotation_matrix.T
    )
    rotated_cell_vecs: Float[Array, "3 3"] = cell_vecs @ rotation_matrix.T
    atomic_numbers: Float[Array, "N"] = bulk_crystal.cart_positions[:, 3]

    # Crop box in the rotated frame: top face at z = 0, extending ``depth``
    # below it.
    box_lo: Float[Array, "3"] = jnp.array([0.0, 0.0, -depth])
    box_hi: Float[Array, "3"] = jnp.array([x_extent, y_extent, 0.0])

    lower, upper = _lattice_index_bounds(
        rotated_cell_vecs, rotated_positions, box_lo, box_hi
    )
    index_axes = [jnp.arange(lo, hi + 1) for lo, hi in zip(lower, upper)]
    index_grid: Int[Array, "R 3"] = jnp.stack(
        jnp.meshgrid(*index_axes, indexing="ij"), axis=-1
    ).reshape(-1, 3)
    translations: Float[Array, "R 3"] = (
        index_grid.astype(rotated_cell_vecs.dtype) @ rotated_cell_vecs
    )

    # Row r * N + n is atom n shifted by translation r; jnp.tile matches that.
    supercell_positions: Float[Array, "M 3"] = (
        rotated_positions[None, :, :] + translations[:, None, :]
    ).reshape(-1, 3)
    supercell_atomic_nums: Float[Array, "M"] = jnp.tile(
        atomic_numbers, translations.shape[0]
    )

    box_size: Float[Array, "3"] = box_hi - box_lo
    local_positions: Float[Array, "M 3"] = supercell_positions - box_lo
    # Tolerance so atoms lying on a box face are not lost to rounding.
    tolerance = 1e-6
    inside: Bool[Array, "M"] = jnp.all(
        (local_positions >= -tolerance)
        & (local_positions <= box_size + tolerance),
        axis=1,
    )
    filtered_positions: Float[Array, "K 3"] = jnp.clip(
        local_positions[inside], min=0.0, max=box_size
    )
    filtered_atomic_nums: Float[Array, "K"] = supercell_atomic_nums[inside]
    final_positions: Float[Array, "K 4"] = jnp.column_stack(
        [filtered_positions, filtered_atomic_nums]
    )

    new_cell_lengths: Float[Array, "3"] = jnp.array(
        [x_extent, y_extent, depth]
    )
    new_cell_angles: Float[Array, "3"] = jnp.array([90.0, 90.0, 90.0])
    return create_sliced_crystal(
        cart_positions=final_positions,
        cell_lengths=new_cell_lengths,
        cell_angles=new_cell_angles,
        orientation=orientation,
        depth=depth,
        x_extent=x_extent,
        y_extent=y_extent,
    )
# Public API of this module, listed alphabetically.
__all__: list[str] = [
    "atom_scraper",
    "build_cell_vectors",
    "bulk_to_slice",
    "compute_lengths_angles",
    "generate_reciprocal_points",
    "get_unit_cell_matrix",
    "miller_to_reciprocal",
    "reciprocal_lattice_vectors",
    "reciprocal_unitcell",
]