Source code for rheedium.recon.losses

r"""Loss functions and residual builders for inverse RHEED problems.

Extended Summary
----------------
This module provides differentiable helpers for comparing simulated
detector images against experimental images. The functions are written
to stay lightweight and composable so forward models from
:mod:`rheedium.simul` and :mod:`rheedium.procs` can be optimized
through the same reconstruction routines.

Routine Listings
----------------
:func:`weighted_image_residual`
    Build a weighted least-squares residual field between two images.
:func:`weighted_mean_squared_error`
    Compute a normalized weighted mean-squared error.
"""

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Optional
from jaxtyping import Array, Float, jaxtyped

from rheedium.types import scalar_float


[docs] @jaxtyped(typechecker=beartype) def weighted_image_residual( simulated_image: Float[Array, "H W"], experimental_image: Float[Array, "H W"], weight_map: Optional[Float[Array, "H W"]] = None, ) -> Float[Array, "H W"]: r"""Build a weighted least-squares residual field. Extended Summary ---------------- The returned residual is designed for Gauss-Newton and other least-squares solvers. When a ``weight_map`` is provided it is interpreted as a non-negative per-pixel reliability weight: .. math:: r_{ij} = \sqrt{w_{ij}}\,(I^{\text{sim}}_{ij} - I^{\text{exp}}_{ij}) This convention ensures that minimizing :math:`\sum r_{ij}^2` matches the standard weighted least-squares objective. Parameters ---------- simulated_image : Float[Array, "H W"] Simulated detector image. experimental_image : Float[Array, "H W"] Experimental target image. weight_map : Float[Array, "H W"], optional Non-negative pixel weights. Values of zero exclude pixels from the fit. Negative entries are clipped to zero. Returns ------- residual : Float[Array, "H W"] Weighted residual image with the same shape as the inputs. """ residual: Float[Array, "H W"] = simulated_image - experimental_image if weight_map is None: return residual clipped_weight_map: Float[Array, "H W"] = jnp.maximum(weight_map, 0.0) return jnp.sqrt(clipped_weight_map) * residual
[docs] @jaxtyped(typechecker=beartype) def weighted_mean_squared_error( simulated_image: Float[Array, "H W"], experimental_image: Float[Array, "H W"], weight_map: Optional[Float[Array, "H W"]] = None, ) -> scalar_float: r"""Compute a normalized weighted mean-squared error. Extended Summary ---------------- Without weights this reduces to the ordinary mean-squared error. With weights it computes: .. math:: \mathrm{WMSE} = \frac{\sum_{ij} w_{ij}(I^{\text{sim}}_{ij} - I^{\text{exp}}_{ij})^2} {\max\left(\sum_{ij} w_{ij}, 10^{-12}\right)} Parameters ---------- simulated_image : Float[Array, "H W"] Simulated detector image. experimental_image : Float[Array, "H W"] Experimental target image. weight_map : Float[Array, "H W"], optional Non-negative pixel weights. Negative entries are clipped to zero before normalization. Returns ------- loss : scalar_float Weighted mean-squared error. """ squared_error: Float[Array, "H W"] = ( simulated_image - experimental_image ) ** 2 if weight_map is None: return jnp.mean(squared_error) clipped_weight_map: Float[Array, "H W"] = jnp.maximum(weight_map, 0.0) normalization: scalar_float = jnp.maximum( jnp.sum(clipped_weight_map), 1e-12 ) return jnp.sum(clipped_weight_map * squared_error) / normalization
__all__: list[str] = [ "weighted_image_residual", "weighted_mean_squared_error", ]