"""Functions for unit cell calculations and transformations.
Extended Summary
----------------
This module provides functions for crystallographic unit cell operations
including reciprocal space calculations, lattice transformations, and atom
filtering for specific zones and thicknesses.
Routine Listings
----------------
:func:`bulk_to_slice`
Transform bulk crystal structure into surface-oriented
slab for RHEED simulation.
:func:`reciprocal_unitcell`
Calculate reciprocal unit cell parameters from direct
cell parameters.
:func:`get_unit_cell_matrix`
Build the fractional-to-Cartesian matrix whose columns
are the direct lattice vectors.
:func:`build_cell_vectors`
Construct unit cell vectors from lengths and angles.
:func:`compute_lengths_angles`
Compute unit cell lengths and angles from lattice
vectors.
:func:`generate_reciprocal_points`
Generate reciprocal lattice points for given hkl
ranges.
:func:`atom_scraper`
Filter atoms within specified thickness along zone
axis.
:func:`reciprocal_lattice_vectors`
Generate reciprocal lattice basis vectors b₁, b₂, b₃.
:func:`miller_to_reciprocal`
Convert Miller indices to reciprocal space vectors.
Notes
-----
All functions are JAX-compatible and support automatic differentiation.
"""
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Tuple
from jax import lax
from jaxtyping import Array, Bool, Float, Int, Num, jaxtyped
from rheedium.types import (
CrystalStructure,
SlicedCrystal,
create_crystal_structure,
create_sliced_crystal,
scalar_bool,
scalar_float,
scalar_int,
)
[docs]
@jaxtyped(typechecker=beartype)
def reciprocal_unitcell(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
    in_degrees: scalar_bool = True,
    out_degrees: scalar_bool = True,
) -> Tuple[Float[Array, "3"], Float[Array, "3"]]:
    """
    Calculate reciprocal unit cell parameters from direct cell parameters.

    The reciprocal lengths follow the physics (2π) convention,
    e.g. ``a* = 2π b c sin(α) / V``. The reciprocal angles use the
    standard relation
    ``cos(α*) = (cos β cos γ − cos α) / (sin β sin γ)`` and its cyclic
    permutations.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in Angstroms.
    alpha : scalar_float
        Direct cell angle α (between b and c axes).
    beta : scalar_float
        Direct cell angle β (between a and c axes).
    gamma : scalar_float
        Direct cell angle γ (between a and b axes).
    in_degrees : bool
        If True, the input angles are interpreted as degrees,
        otherwise as radians. Default: True.
    out_degrees : bool
        If True, the returned angles are in degrees, otherwise in
        radians. Default: True.

    Returns
    -------
    reciprocal_lengths : Float[Array, "3"]
        Reciprocal cell lengths [a*, b*, c*] in 1/Angstroms.
    reciprocal_angles : Float[Array, "3"]
        Reciprocal cell angles [α*, β*, γ*] in degrees or radians.

    Notes
    -----
    The cosines of the reciprocal angles are clipped to [-1, 1] before
    ``arccos`` to absorb floating-point overshoot. Geometrically
    impossible angle triples give a negative value under the volume
    square root, which propagates as NaN.

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    build_cell_vectors : Build direct lattice vectors.
    """

    def _as_radians(angle):
        # Conversion is selected with jnp.where so in_degrees may be traced.
        return jnp.where(in_degrees, jnp.deg2rad(angle), angle)

    def _as_output(angle_rad):
        return jnp.where(out_degrees, jnp.rad2deg(angle_rad), angle_rad)

    alpha_r, beta_r, gamma_r = (_as_radians(t) for t in (alpha, beta, gamma))
    cos_a, cos_b, cos_g = jnp.cos(alpha_r), jnp.cos(beta_r), jnp.cos(gamma_r)
    sin_a, sin_b, sin_g = jnp.sin(alpha_r), jnp.sin(beta_r), jnp.sin(gamma_r)

    # (V / abc)^2 expressed through the direct-cell cosines.
    metric_term = (
        1.0
        - cos_a**2
        - cos_b**2
        - cos_g**2
        + 2.0 * cos_a * cos_b * cos_g
    )
    cell_volume = a * b * c * jnp.sqrt(metric_term)

    two_pi = 2.0 * jnp.pi
    lengths_star = jnp.array(
        [
            two_pi * b * c * sin_a / cell_volume,
            two_pi * a * c * sin_b / cell_volume,
            two_pi * a * b * sin_g / cell_volume,
        ]
    )

    def _star_angle(cos_x, cos_y, cos_z, sin_y, sin_z):
        # cos(x*) = (cos y cos z - cos x) / (sin y sin z)
        cos_star = (cos_y * cos_z - cos_x) / (sin_y * sin_z)
        return _as_output(jnp.arccos(jnp.clip(cos_star, -1.0, 1.0)))

    angles_star = jnp.array(
        [
            _star_angle(cos_a, cos_b, cos_g, sin_b, sin_g),
            _star_angle(cos_b, cos_a, cos_g, sin_a, sin_g),
            _star_angle(cos_g, cos_a, cos_b, sin_a, sin_b),
        ]
    )
    return lengths_star, angles_star
[docs]
@jaxtyped(typechecker=beartype)
def get_unit_cell_matrix(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
) -> Float[Array, "3 3"]:
    r"""Build the fractional-to-Cartesian matrix of a unit cell.

    The returned matrix ``M`` has the direct lattice vectors a, b, c as
    its *columns*, using the same convention as ``build_cell_vectors``:
    a along x, and b in the x-y plane. Consequently
    ``M == build_cell_vectors(a, b, c, alpha, beta, gamma).T``, and a
    fractional coordinate ``f`` maps to Cartesian angstroms as
    ``M @ f``.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha, beta, gamma : scalar_float
        Direct cell angles in degrees.

    Returns
    -------
    Float[Array, "3 3"]
        Upper-triangular fractional-to-Cartesian matrix. Columns are the
        lattice vectors. The dtype is requested as float64; JAX falls
        back to float32 when x64 mode is disabled.

    Notes
    -----
    1. **Angle Conversion** --
       Convert cell angles from degrees to radians and take their
       cosines, plus the sine of gamma.
    2. **Volume Factor** --
       ``sqrt(1 - cos²α - cos²β - cos²γ + 2 cosα cosβ cosγ)``, which is
       the cell volume divided by ``a*b*c``.
    3. **Matrix Assembly** --
       Column 0 is ``[a, 0, 0]``. Column 1 is ``[b cosγ, b sinγ, 0]``.
       Column 2 is ``[c cosβ, c (cosα − cosβ cosγ)/sinγ,
       c·V_factor/sinγ]``.

    Examples
    --------
    >>> import rheedium as rh
    >>> import jax.numpy as jnp
    >>>
    >>> # Fractional-to-Cartesian matrix for a 3 Å cubic cell
    >>> matrix = rh.ucell.get_unit_cell_matrix(
    ...     a=3.0, b=3.0, c=3.0, alpha=90.0, beta=90.0, gamma=90.0
    ... )
    >>>
    >>> # Convert the fractional body-centre position to Cartesian
    >>> frac = jnp.array([0.5, 0.5, 0.5])
    >>> cart = matrix @ frac  # -> [1.5, 1.5, 1.5] angstroms

    See Also
    --------
    build_cell_vectors : Same lattice vectors, stored as rows.
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    """
    alpha_rad: Float[Array, ""] = jnp.radians(alpha)
    beta_rad: Float[Array, ""] = jnp.radians(beta)
    gamma_rad: Float[Array, ""] = jnp.radians(gamma)
    cos_angles: Float[Array, "3"] = jnp.array(
        [jnp.cos(alpha_rad), jnp.cos(beta_rad), jnp.cos(gamma_rad)]
    )
    cos_alpha = cos_angles[0]
    cos_beta = cos_angles[1]
    cos_gamma = cos_angles[2]
    sin_gamma: Float[Array, ""] = jnp.sin(gamma_rad)
    # V / (a b c); enters the z-component of the c lattice vector.
    volume_factor: Float[Array, ""] = jnp.sqrt(
        1 - jnp.sum(jnp.square(cos_angles)) + (2 * jnp.prod(cos_angles))
    )
    # Columns are the lattice vectors a, b, c (upper-triangular form).
    matrix: Float[Array, "3 3"] = jnp.array(
        [
            [a, b * cos_gamma, c * cos_beta],
            [
                0.0,
                b * sin_gamma,
                c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma,
            ],
            [0.0, 0.0, c * volume_factor / sin_gamma],
        ],
        dtype=jnp.float64,
    )
    return matrix
[docs]
@jaxtyped(typechecker=beartype)
def build_cell_vectors(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
) -> Float[Array, "3 3"]:
    r"""Construct unit cell vectors from lengths and angles.

    The standard orientation is used: a along +x, and b in the x-y
    plane. The c vector is fixed by alpha, beta and gamma, with a
    non-negative z component.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in angstroms.
    alpha, beta, gamma : scalar_float
        Direct cell angles in degrees.

    Returns
    -------
    Float[Array, "3 3"]
        Lattice vectors a, b, c stacked as the rows of a 3x3 matrix.

    Notes
    -----
    ``c_z`` is obtained as ``sqrt(c² − c_x² − c_y²)``. The radicand is
    clipped at zero, so rounding (or an inconsistent angle triple) yields
    ``c_z = 0`` instead of NaN.

    Examples
    --------
    >>> import rheedium as rh
    >>> import jax.numpy as jnp
    >>>
    >>> vectors = rh.ucell.build_cell_vectors(
    ...     a=3.0, b=3.0, c=3.0, alpha=90.0, beta=90.0, gamma=90.0
    ... )
    >>> volume = jnp.linalg.det(vectors)  # 27 Å^3 for this cubic cell

    See Also
    --------
    compute_lengths_angles : Inverse operation from vectors to parameters.
    reciprocal_lattice_vectors : Build reciprocal lattice vectors.
    """
    alpha_r = jnp.radians(alpha)
    beta_r = jnp.radians(beta)
    gamma_r = jnp.radians(gamma)
    cos_g = jnp.cos(gamma_r)
    sin_g = jnp.sin(gamma_r)

    # c vector: x from beta, y from the alpha/beta/gamma relation, z by
    # closing the length.
    cx = c * jnp.cos(beta_r)
    cy = c * ((jnp.cos(alpha_r) - jnp.cos(beta_r) * cos_g) / sin_g)
    cz = jnp.sqrt(jnp.clip((c**2) - (cx**2) - (cy**2), min=0.0))

    rows = [
        jnp.array([a, 0.0, 0.0]),
        jnp.array([b * cos_g, b * sin_g, 0.0]),
        jnp.array([cx, cy, cz]),
    ]
    return jnp.stack(rows, axis=0)
@jaxtyped(typechecker=beartype)
def compute_lengths_angles(
    vectors: Float[Array, "3 3"],
) -> Tuple[Float[Array, "3"], Float[Array, "3"]]:
    """Compute unit cell lengths and angles from lattice vectors.

    Parameters
    ----------
    vectors : Float[Array, "3 3"]
        Lattice vectors a, b, c as the rows of a 3x3 matrix.

    Returns
    -------
    lengths : Float[Array, "3"]
        Row norms [|a|, |b|, |c|] (angstroms if the input is).
    angles : Float[Array, "3"]
        Cell angles [α, β, γ] in degrees. α is the angle between b and c,
        β between a and c, and γ between a and b.

    Notes
    -----
    Each cosine is the normalised dot product of the corresponding pair
    of rows. It is clipped to [-1, 1] before ``arccos`` so rounding cannot
    produce NaN.

    Examples
    --------
    >>> import rheedium as rh
    >>> import jax.numpy as jnp
    >>>
    >>> vectors = jnp.array(
    ...     [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
    ... )
    >>> lengths, angles = rh.ucell.compute_lengths_angles(vectors)
    >>> # lengths -> [3, 3, 3], angles -> [90, 90, 90]

    See Also
    --------
    build_cell_vectors : Inverse operation from parameters to vectors.
    """
    lengths = jnp.linalg.norm(vectors, axis=1)
    # Row-index pairs for alpha (b, c), beta (a, c) and gamma (a, b).
    row_pairs = ((1, 2), (0, 2), (0, 1))
    cosines = jnp.array(
        [
            jnp.dot(vectors[i], vectors[j]) / (lengths[i] * lengths[j])
            for i, j in row_pairs
        ]
    )
    angles = jnp.rad2deg(jnp.arccos(jnp.clip(cosines, -1.0, 1.0)))
    return lengths, angles
[docs]
@jaxtyped(typechecker=beartype)
def generate_reciprocal_points(
    crystal: CrystalStructure,
    hmax: scalar_int,
    kmax: scalar_int,
    lmax: scalar_int,
    in_degrees: scalar_bool = True,
) -> Float[Array, "M 3"]:
    r"""Generate reciprocal-lattice vectors based on the crystal structure.

    Every integer triple with ``-hmax <= h <= hmax``, ``-kmax <= k <= kmax``
    and ``-lmax <= l <= lmax`` is enumerated, including (0, 0, 0). Each
    triple is mapped to ``G = h b1 + k b2 + l b3``.

    Parameters
    ----------
    crystal : CrystalStructure
        Crystal whose ``cell_lengths`` / ``cell_angles`` define the lattice.
    hmax, kmax, lmax : scalar_int
        Largest absolute h, k, l index to generate.
    in_degrees : bool, optional
        Whether ``crystal.cell_angles`` are in degrees. Default: True.

    Returns
    -------
    Float[Array, "M 3"]
        Reciprocal lattice vectors in 1/angstroms, with
        ``M = (2 hmax + 1)(2 kmax + 1)(2 lmax + 1)``. Rows are ordered
        with h varying slowest and l fastest.

    Examples
    --------
    >>> import rheedium as rh
    >>>
    >>> crystal = rh.inout.parse_cif("path/to/crystal.cif")
    >>> g_vectors = rh.ucell.generate_reciprocal_points(
    ...     crystal=crystal, hmax=2, kmax=2, lmax=1
    ... )
    >>> g_vectors.shape  # (5 * 5 * 3, 3)

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    miller_to_reciprocal : Convert Miller indices to G vectors.
    """
    lengths = crystal.cell_lengths
    cell_angles = crystal.cell_angles
    basis = reciprocal_lattice_vectors(
        lengths[0],
        lengths[1],
        lengths[2],
        cell_angles[0],
        cell_angles[1],
        cell_angles[2],
        in_degrees=in_degrees,
    )
    # One symmetric integer range per Miller index.
    index_axes = [jnp.arange(-m, m + 1) for m in (hmax, kmax, lmax)]
    # "ij" indexing keeps h slowest / l fastest after ravel().
    mesh = jnp.meshgrid(*index_axes, indexing="ij")
    hkl = jnp.stack([grid.ravel() for grid in mesh], axis=-1)
    return miller_to_reciprocal(hkl, basis)
[docs]
@jaxtyped(typechecker=beartype)
def atom_scraper(
    crystal: CrystalStructure,
    zone_axis: Float[Array, "3"],
    thickness: Float[Array, "3"],
) -> CrystalStructure:
    """Filter atoms within specified thickness along zone axis.

    Parameters
    ----------
    crystal : CrystalStructure
        Crystal structure to filter.
    zone_axis : Float[Array, "3"]
        Zone axis direction (Cartesian). It need not be normalised.
    thickness : Float[Array, "3"]
        Thickness vector. Only its projection onto the normalised zone
        axis, ``|thickness · zone_axis_hat|``, is used. A projection of
        (approximately) zero selects "top-layer mode".

    Returns
    -------
    filtered_crystal : CrystalStructure
        Structure containing only the selected atoms. Its cell is rescaled
        along the zone axis to the new slab height.

    Notes
    -----
    1. **Build Cell Vectors** --
       Construct direct lattice vectors from the cell parameters.
    2. **Normalize Zone Axis** --
       Compute a unit vector along ``zone_axis``.
    3. **Project and Measure** --
       Project each Cartesian position onto the zone axis and measure
       its distance below the highest projection.
    4. **Adaptive Threshold** --
       ``eps = max(1e-3, 2 * d_min)``, where ``d_min`` is the smallest
       distance larger than 1e-8. If every atom lies in the top plane,
       ``eps = 1e-3``.
    5. **Apply Thickness Mask** --
       Keep atoms with distance <= projected thickness. In top-layer mode,
       keep atoms with distance <= ``eps`` instead.
    6. **Gather Filtered Atoms** --
       Select the matching rows of ``frac_positions`` and
       ``cart_positions``.
    7. **Scale Cell Vectors** --
       Scale the component of each lattice vector parallel to the zone
       axis by ``new_height / original_height``. Perpendicular components
       are preserved. ``original_height`` is the spread of the projected
       positions.
    8. **Build Output Structure** --
       Create a new ``CrystalStructure`` from the filtered positions and
       the rescaled cell.

    Examples
    --------
    >>> import rheedium as rh
    >>> import jax.numpy as jnp
    >>>
    >>> crystal = rh.inout.parse_cif("path/to/crystal.cif")
    >>> filtered = rh.ucell.atom_scraper(
    ...     crystal=crystal,
    ...     zone_axis=jnp.array([1.0, 1.0, 1.0]),
    ...     thickness=jnp.array([12.0, 12.0, 12.0]),
    ... )

    See Also
    --------
    build_cell_vectors : Construct unit cell vectors.
    compute_lengths_angles : Compute cell parameters from vectors.
    create_crystal_structure : Create filtered crystal structure.
    """
    orig_cell_vectors: Float[Array, "3 3"] = build_cell_vectors(
        crystal.cell_lengths[0],
        crystal.cell_lengths[1],
        crystal.cell_lengths[2],
        crystal.cell_angles[0],
        crystal.cell_angles[1],
        crystal.cell_angles[2],
    )
    zone_axis_norm: Float[Array, ""] = jnp.linalg.norm(zone_axis)
    zone_axis_hat: Float[Array, "3"] = zone_axis / (zone_axis_norm + 1e-32)
    cart_xyz: Float[Array, "n 3"] = crystal.cart_positions[:, :3]
    dot_vals: Float[Array, "n"] = jnp.einsum(
        "ij,j->i", cart_xyz, zone_axis_hat
    )
    d_max: Float[Array, ""] = jnp.max(dot_vals)
    dist_from_top: Float[Array, "n"] = d_max - dot_vals
    distance_cutoff: float = 1e-8
    # Smallest strictly positive distance below the top plane. Atoms in
    # the top plane are masked with +inf instead of being removed by
    # boolean indexing. An empty reduction (all atoms coplanar, e.g. a
    # one-atom cell) would raise, and masking also keeps this step
    # traceable.
    below_top: Bool[Array, "n"] = dist_from_top > distance_cutoff
    min_positive: Float[Array, ""] = jnp.min(
        jnp.where(below_top, dist_from_top, jnp.inf)
    )
    # NOTE(review): because eps >= 2 * (gap to the next plane),
    # "top-layer" mode also keeps that next plane; confirm this is
    # intended.
    adaptive_eps: Float[Array, ""] = jnp.where(
        jnp.any(below_top),
        jnp.maximum(1e-3, 2 * min_positive),
        1e-3,
    )
    # Project thickness onto zone axis for scalar
    # thickness along that direction
    thickness_along_axis: Float[Array, ""] = jnp.abs(
        jnp.dot(thickness, zone_axis_hat)
    )
    is_top_layer_mode: Bool[Array, ""] = jnp.isclose(
        thickness_along_axis, jnp.asarray(0.0), atol=1e-8
    )
    mask: Bool[Array, "n"] = jnp.where(
        is_top_layer_mode,
        dist_from_top <= adaptive_eps,
        dist_from_top <= thickness_along_axis,
    )
    filtered_frac: Float[Array, "m 4"] = crystal.frac_positions[mask]
    filtered_cart: Float[Array, "m 4"] = crystal.cart_positions[mask]
    original_height: Float[Array, ""] = jnp.max(dot_vals) - jnp.min(dot_vals)
    new_height: Float[Array, ""] = jnp.where(
        is_top_layer_mode,
        adaptive_eps,
        jnp.minimum(thickness_along_axis, original_height),
    )

    def _scale_vector(
        vec: Float[Array, "3"],
        zone_axis_hat: Float[Array, "3"],
        old_height: Float[Array, ""],
        new_height: Float[Array, ""],
    ) -> Float[Array, "3"]:
        """Rescale the zone-axis component of ``vec`` by new/old height."""
        height_cutoff: float = 1e-32
        proj_mag: Float[Array, ""] = jnp.dot(vec, zone_axis_hat)
        parallel_comp: Float[Array, "3"] = proj_mag * zone_axis_hat
        perp_comp: Float[Array, "3"] = vec - parallel_comp
        degenerate: Bool[Array, ""] = old_height < height_cutoff
        # Divide by a safe denominator so the unselected branch never
        # produces inf/NaN. Otherwise NaN would leak into gradients
        # through jnp.where.
        safe_old: Float[Array, ""] = jnp.where(degenerate, 1.0, old_height)
        scale_factor: Float[Array, ""] = jnp.where(
            degenerate, 1.0, new_height / safe_old
        )
        scaled_parallel: Float[Array, "3"] = scale_factor * parallel_comp
        return scaled_parallel + perp_comp

    def _scale_if_needed(
        vec: Float[Array, "3"],
        zone_axis_hat: Float[Array, "3"],
        original_height: Float[Array, ""],
        new_height: Float[Array, ""],
    ) -> Float[Array, "3"]:
        """Only rescale vectors with a non-negligible zone-axis component."""
        needs_scaling: Bool[Array, ""] = (
            jnp.abs(jnp.dot(vec, zone_axis_hat)) > distance_cutoff
        )
        scaled: Float[Array, "3"] = _scale_vector(
            vec, zone_axis_hat, original_height, new_height
        )
        return jnp.where(needs_scaling, scaled, vec)

    scaled_vectors: Float[Array, "3 3"] = jnp.stack(
        [
            _scale_if_needed(
                orig_cell_vectors[i],
                zone_axis_hat,
                original_height,
                new_height,
            )
            for i in range(3)
        ]
    )
    new_lengths: Float[Array, "3"]
    new_angles: Float[Array, "3"]
    new_lengths, new_angles = compute_lengths_angles(scaled_vectors)
    filtered_crystal: CrystalStructure = create_crystal_structure(
        frac_positions=filtered_frac,
        cart_positions=filtered_cart,
        cell_lengths=new_lengths,
        cell_angles=new_angles,
    )
    return filtered_crystal
[docs]
@jaxtyped(typechecker=beartype)
def reciprocal_lattice_vectors(
    a: scalar_float,
    b: scalar_float,
    c: scalar_float,
    alpha: scalar_float,
    beta: scalar_float,
    gamma: scalar_float,
    in_degrees: scalar_bool = True,
) -> Float[Array, "3 3"]:
    """Generate reciprocal lattice basis vectors b₁, b₂, b₃.

    Uses the crystallographic relationships (physics 2π convention):

        b₁ = 2π(a₂ × a₃)/(a₁ · (a₂ × a₃))
        b₂ = 2π(a₃ × a₁)/(a₁ · (a₂ × a₃))
        b₃ = 2π(a₁ × a₂)/(a₁ · (a₂ × a₃))

    so that ``a_i · b_j = 2π δ_ij``.

    Parameters
    ----------
    a, b, c : scalar_float
        Direct cell lengths in Angstroms.
    alpha : scalar_float
        Direct cell angle α (between b and c axes).
    beta : scalar_float
        Direct cell angle β (between a and c axes).
    gamma : scalar_float
        Direct cell angle γ (between a and b axes).
    in_degrees : bool
        If True, the input angles are in degrees, otherwise in radians.
        Default: True.

    Returns
    -------
    reciprocal_vectors : Float[Array, "3 3"]
        Reciprocal basis vectors [b₁, b₂, b₃] as rows, in 1/Angstroms.
        They are expressed in the same Cartesian frame as
        ``build_cell_vectors``.

    Notes
    -----
    1. **Angle Normalisation** --
       ``build_cell_vectors`` takes degrees, so radian input is converted
       to degrees. Degree input is passed through unchanged, which avoids
       a lossy degrees→radians→degrees round trip.
    2. **Build Direct Vectors** --
       Construct a₁, a₂, a₃ via ``build_cell_vectors``.
    3. **Cross Products and Volume** --
       Compute a₂×a₃, a₃×a₁, a₁×a₂ and the signed volume a₁·(a₂×a₃).
    4. **Scale** --
       Multiply each cross product by 2π / volume and stack them as rows.

    See Also
    --------
    build_cell_vectors : Build direct lattice vectors.
    reciprocal_unitcell : Compute reciprocal cell parameters.
    miller_to_reciprocal : Convert Miller indices to G vectors.
    """
    # jnp.where (not a Python branch) so in_degrees may be a traced bool.
    alpha_deg: Float[Array, ""] = jnp.where(
        in_degrees, alpha, jnp.rad2deg(alpha)
    )
    beta_deg: Float[Array, ""] = jnp.where(in_degrees, beta, jnp.rad2deg(beta))
    gamma_deg: Float[Array, ""] = jnp.where(
        in_degrees, gamma, jnp.rad2deg(gamma)
    )
    direct_vectors: Float[Array, "3 3"] = build_cell_vectors(
        a, b, c, alpha_deg, beta_deg, gamma_deg
    )
    a_vec: Float[Array, "3"] = direct_vectors[0]
    b_vec: Float[Array, "3"] = direct_vectors[1]
    c_vec: Float[Array, "3"] = direct_vectors[2]
    cross_b_c: Float[Array, "3"] = jnp.cross(b_vec, c_vec)
    cross_c_a: Float[Array, "3"] = jnp.cross(c_vec, a_vec)
    cross_a_b: Float[Array, "3"] = jnp.cross(a_vec, b_vec)
    volume: Float[Array, ""] = jnp.dot(a_vec, cross_b_c)
    scale_factor: Float[Array, ""] = 2.0 * jnp.pi / volume
    reciprocal_vectors: Float[Array, "3 3"] = jnp.stack(
        [
            scale_factor * cross_b_c,
            scale_factor * cross_c_a,
            scale_factor * cross_a_b,
        ],
        axis=0,
    )
    return reciprocal_vectors
[docs]
@jaxtyped(typechecker=beartype)
def miller_to_reciprocal(
    hkl: Int[Array, "... 3"],
    reciprocal_vectors: Float[Array, "3 3"],
) -> Float[Array, "... 3"]:
    """Convert Miller indices to reciprocal space vectors.

    Computes ``G = h*b₁ + k*b₂ + l*b₃`` for every (h, k, l) triple,
    broadcasting over any number of leading batch dimensions.

    Parameters
    ----------
    hkl : Int[Array, "... 3"]
        Miller indices. The last axis holds [h, k, l]; leading axes are
        treated as a batch.
    reciprocal_vectors : Float[Array, "3 3"]
        Reciprocal basis vectors b₁, b₂, b₃ as rows, in 1/Angstroms, for
        example as produced by ``reciprocal_lattice_vectors``.

    Returns
    -------
    g_vectors : Float[Array, "... 3"]
        Reciprocal space vectors in 1/Angstroms, with the same batch
        shape as ``hkl``.

    Notes
    -----
    The indices are cast to float64 before the combination; JAX falls
    back to float32 when x64 mode is disabled. The three contributions
    are accumulated in h, k, l order.

    See Also
    --------
    reciprocal_lattice_vectors : Generate reciprocal basis vectors.
    generate_reciprocal_points : Generate G vectors from crystal structure.
    """
    coefficients = jnp.asarray(hkl, dtype=jnp.float64)
    # Length-1 slices keep a trailing axis so each coefficient broadcasts
    # against a basis row of shape (3,).
    g_vectors = coefficients[..., 0:1] * reciprocal_vectors[0]
    g_vectors = g_vectors + coefficients[..., 1:2] * reciprocal_vectors[1]
    g_vectors = g_vectors + coefficients[..., 2:3] * reciprocal_vectors[2]
    return g_vectors
[docs]
def _rotation_onto_z(normal: Float[Array, "3"]) -> Float[Array, "3 3"]:
    """Return a proper rotation ``R`` with ``R @ normal_hat == [0, 0, 1]``."""
    normal_hat = normal / jnp.linalg.norm(normal)
    dtype = normal_hat.dtype
    z_axis = jnp.array([0.0, 0.0, 1.0], dtype=dtype)
    rot_axis = jnp.cross(normal_hat, z_axis)
    rot_axis_norm = jnp.linalg.norm(rot_axis)
    cos_angle = jnp.clip(jnp.dot(normal_hat, z_axis), -1.0, 1.0)
    angle = jnp.arccos(cos_angle)
    identity = jnp.eye(3, dtype=dtype)

    def _parallel() -> Float[Array, "3 3"]:
        # Normal along +z: identity. Along -z the axis n x z vanishes,
        # so use a half-turn about x (diag(1, -1, -1)) to flip it onto +z.
        flip = jnp.diag(jnp.array([1.0, -1.0, -1.0], dtype=dtype))
        return jnp.where(cos_angle > 0.0, identity, flip)

    def _rodrigues() -> Float[Array, "3 3"]:
        # Rodrigues' formula about k = (n x z)/|n x z| by angle(n, z).
        k = rot_axis / (rot_axis_norm + 1e-10)
        skew = jnp.array(
            [
                [0.0, -k[2], k[1]],
                [k[2], 0.0, -k[0]],
                [-k[1], k[0], 0.0],
            ]
        )
        return (
            identity
            + jnp.sin(angle) * skew
            + (1.0 - jnp.cos(angle)) * (skew @ skew)
        )

    return lax.cond(rot_axis_norm < 1e-6, _parallel, _rodrigues)


@jaxtyped(typechecker=beartype)
def bulk_to_slice(  # noqa: PLR0915
    bulk_crystal: CrystalStructure,
    orientation: Int[Array, "3"],
    depth: scalar_float,
    x_extent: scalar_float = 150.0,
    y_extent: scalar_float = 150.0,
) -> SlicedCrystal:
    """Transform a bulk crystal structure into a surface-oriented slab.

    The bulk crystal is rotated so that the normal of the (hkl) lattice
    planes points along +z. The lattice is then replicated along all
    three lattice vectors until it fills the box
    ``[0, x_extent] x [0, y_extent] x [0, depth]``, and the atoms inside
    that box are kept.

    Parameters
    ----------
    bulk_crystal : CrystalStructure
        Bulk crystal structure from CIF file or other source. Column 3 of
        ``cart_positions`` is carried through as the atomic number.
    orientation : Int[Array, "3"]
        Miller indices [h, k, l] of the desired surface, e.g. [1, 1, 1]
        for (111) or [0, 0, 1] for (001). At least one index must be
        nonzero.
    depth : scalar_float
        Slab thickness perpendicular to the surface in Angstroms.
        Typically 10-30 Angstroms.
    x_extent : scalar_float, optional
        Lateral extent in x in Angstroms. Default: 150.0. Should be
        >= 100 Angstroms for realistic RHEED simulations.
    y_extent : scalar_float, optional
        Lateral extent in y in Angstroms. Default: 150.0. Should be
        >= 100 Angstroms for realistic RHEED simulations.

    Returns
    -------
    sliced_crystal : SlicedCrystal
        Slab with Cartesian positions in the rotated frame (z = surface
        normal). The cell lengths are ``[x_extent, y_extent, depth]`` and
        all cell angles are 90°.

    Raises
    ------
    ValueError
        If all three Miller indices are zero.

    Notes
    -----
    1. **Orient** -- Build ``G_hkl = h b1 + k b2 + l b3`` and a rotation
       taking it onto +z. A normal anti-parallel to z gets a half-turn
       about x. Atom positions, lattice vectors and reciprocal vectors are
       all rotated.
    2. **Index ranges** -- The fractional coordinates ``r·b_i / 2π`` of
       the 8 box corners and of the basis atoms give, for each lattice
       vector, the exact integer range of translations needed to cover
       the box.
    3. **Replicate** -- Every translation ``n1 a + n2 b + n3 c`` in those
       ranges is added to every basis atom.
    4. **Crop** -- Keep atoms inside the box, with a 1e-6 Å tolerance so
       planes that lie exactly on a face survive rounding.
    5. **Shift** -- Translate so the lowest kept atom along each axis is
       at 0. The bottom surface is then z = 0 and all positions lie in
       ``[0, extent]``.

    The number of generated candidates grows roughly with
    ``x_extent * y_extent * depth / V_cell``. Unless x_extent and
    y_extent are commensurate with the surface lattice, the crop is not
    exactly periodic across the x/y cell boundaries. The function
    concretises array sizes with ``int(...)``, so it is not jit-compatible.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import rheedium as rh
    >>>
    >>> bulk = rh.inout.parse_cif("SrTiO3.cif")
    >>> slab = rh.ucell.bulk_to_slice(
    ...     bulk_crystal=bulk,
    ...     orientation=jnp.array([1, 1, 1]),
    ...     depth=20.0,
    ...     x_extent=150.0,
    ...     y_extent=150.0,
    ... )
    """
    orientation = jnp.asarray(orientation, dtype=jnp.int32)
    depth = jnp.asarray(depth, dtype=jnp.float64)
    x_extent = jnp.asarray(x_extent, dtype=jnp.float64)
    y_extent = jnp.asarray(y_extent, dtype=jnp.float64)
    if bool(jnp.all(orientation == 0)):
        raise ValueError(
            "orientation must contain at least one nonzero Miller index"
        )

    cell_vecs: Float[Array, "3 3"] = build_cell_vectors(
        *bulk_crystal.cell_lengths, *bulk_crystal.cell_angles
    )
    recip_vecs: Float[Array, "3 3"] = reciprocal_lattice_vectors(
        *bulk_crystal.cell_lengths,
        *bulk_crystal.cell_angles,
        in_degrees=True,
    )

    # 1. Rotate so the (hkl) plane normal points along +z.
    hkl_cart: Float[Array, "3"] = orientation.astype(recip_vecs.dtype) @ recip_vecs
    rotation_matrix: Float[Array, "3 3"] = _rotation_onto_z(hkl_cart)
    rotated_cell_vecs: Float[Array, "3 3"] = cell_vecs @ rotation_matrix.T
    # Rotation preserves a_i . b_j = 2 pi delta_ij.
    rotated_recip_vecs: Float[Array, "3 3"] = recip_vecs @ rotation_matrix.T
    rotated_positions: Float[Array, "N 3"] = (
        bulk_crystal.cart_positions[:, :3] @ rotation_matrix.T
    )
    atomic_numbers: Float[Array, "N"] = bulk_crystal.cart_positions[:, 3]

    # 2. Integer translation ranges covering the target box. The box is
    # convex and fractional coordinates are linear, so the extremes occur
    # at the corners.
    box_size: Float[Array, "3"] = jnp.stack([x_extent, y_extent, depth])
    corner_selectors = jnp.array(
        [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
        dtype=box_size.dtype,
    )
    box_corners: Float[Array, "8 3"] = corner_selectors * box_size
    two_pi = 2.0 * jnp.pi
    corner_frac: Float[Array, "8 3"] = box_corners @ rotated_recip_vecs.T / two_pi
    atom_frac: Float[Array, "N 3"] = (
        rotated_positions @ rotated_recip_vecs.T / two_pi
    )
    n_lo: Float[Array, "3"] = jnp.floor(
        corner_frac.min(axis=0) - atom_frac.max(axis=0)
    )
    n_hi: Float[Array, "3"] = jnp.ceil(
        corner_frac.max(axis=0) - atom_frac.min(axis=0)
    )
    index_axes = [
        jnp.arange(
            int(n_lo[axis]), int(n_hi[axis]) + 1, dtype=rotated_cell_vecs.dtype
        )
        for axis in range(3)
    ]

    # 3. Replicate the basis over every translation n1 a + n2 b + n3 c.
    mesh = jnp.meshgrid(*index_axes, indexing="ij")
    lattice_indices: Float[Array, "R 3"] = jnp.stack(
        [grid.ravel() for grid in mesh], axis=-1
    )
    translations: Float[Array, "R 3"] = lattice_indices @ rotated_cell_vecs
    n_replicas: int = translations.shape[0]
    supercell_positions: Float[Array, "M 3"] = (
        rotated_positions[None, :, :] + translations[:, None, :]
    ).reshape(-1, 3)
    # Translation-major ordering above matches tiling the basis numbers.
    supercell_atomic_nums: Float[Array, "M"] = jnp.tile(
        atomic_numbers, n_replicas
    )

    # 4. Crop to the box (inclusive, with a small tolerance for rounding).
    tol = 1e-6
    inside: Bool[Array, "M"] = jnp.all(
        supercell_positions >= -tol, axis=1
    ) & jnp.all(supercell_positions <= box_size + tol, axis=1)
    kept_positions: Float[Array, "K 3"] = supercell_positions[inside]
    kept_atomic_nums: Float[Array, "K"] = supercell_atomic_nums[inside]

    # 5. Put the lowest kept atom at 0 along each axis (bottom surface z=0).
    if kept_positions.shape[0] > 0:
        kept_positions = kept_positions - kept_positions.min(axis=0)

    final_positions: Float[Array, "K 4"] = jnp.column_stack(
        [kept_positions, kept_atomic_nums]
    )
    new_cell_lengths: Float[Array, "3"] = jnp.array(
        [x_extent, y_extent, depth]
    )
    new_cell_angles: Float[Array, "3"] = jnp.array([90.0, 90.0, 90.0])
    return create_sliced_crystal(
        cart_positions=final_positions,
        cell_lengths=new_cell_lengths,
        cell_angles=new_cell_angles,
        orientation=orientation,
        depth=depth,
        x_extent=x_extent,
        y_extent=y_extent,
    )
# Public API of this module, kept in alphabetical order.
__all__: list[str] = sorted(
    [
        "reciprocal_unitcell",
        "reciprocal_lattice_vectors",
        "miller_to_reciprocal",
        "get_unit_cell_matrix",
        "generate_reciprocal_points",
        "compute_lengths_angles",
        "bulk_to_slice",
        "build_cell_vectors",
        "atom_scraper",
    ]
)