r"""Optimization routines for inverse RHEED problems.
Extended Summary
----------------
This module provides general-purpose optimization routines for
reconstruction problems built on differentiable forward models. The
low-level solvers operate on arbitrary JAX pytrees, while the
high-level wrappers target image-matching workflows that compare a
simulated detector image against an experimental one.
Routine Listings
----------------
:class:`ReconstructionResult`
Result container returned by all reconstruction solvers.
:func:`gauss_newton_least_squares`
Gauss-Newton optimizer for least-squares residual functions.
:func:`adam_optimize`
Adam optimizer for arbitrary scalar objectives.
:func:`adagrad_optimize`
Adagrad optimizer for arbitrary scalar objectives.
:func:`gauss_newton_reconstruction`
Image-matching reconstruction using Gauss-Newton.
:func:`adam_reconstruction`
Image-matching reconstruction using Adam.
:func:`adagrad_reconstruction`
Image-matching reconstruction using Adagrad.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Any, Callable, NamedTuple, Optional, Tuple
from jax.flatten_util import ravel_pytree
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, Bool, Float, Int, jaxtyped
from rheedium.types import scalar_float
from .losses import (
weighted_image_residual,
weighted_mean_squared_error,
)
[docs]
@register_pytree_node_class
class ReconstructionResult(NamedTuple):
"""Container for reconstruction outputs and optimization traces.
Attributes
----------
params : Any
Final reconstructed parameter pytree.
objective_history : Float[Array, "N"]
Objective value after each accepted optimization step.
gradient_norm_history : Float[Array, "N"]
L2 norm of the gradient-like search direction at each iteration.
step_norm_history : Float[Array, "N"]
L2 norm of the parameter update applied at each iteration.
iterations : Int[Array, ""]
Number of recorded optimization iterations.
converged : Bool[Array, ""]
True when a convergence tolerance was met before exhausting the
iteration budget.
"""
params: Any
objective_history: Float[Array, "N"]
gradient_norm_history: Float[Array, "N"]
step_norm_history: Float[Array, "N"]
iterations: Int[Array, ""]
converged: Bool[Array, ""]
def tree_flatten(
self,
) -> Tuple[
Tuple[
Any,
Float[Array, "N"],
Float[Array, "N"],
Float[Array, "N"],
Int[Array, ""],
Bool[Array, ""],
],
None,
]:
"""Flatten the PyTree into a tuple of leaves."""
return (
(
self.params,
self.objective_history,
self.gradient_norm_history,
self.step_norm_history,
self.iterations,
self.converged,
),
None,
)
@classmethod
def tree_unflatten(
cls,
aux_data: None,
children: Tuple[
Any,
Float[Array, "N"],
Float[Array, "N"],
Float[Array, "N"],
Int[Array, ""],
Bool[Array, ""],
],
) -> "ReconstructionResult":
"""Unflatten the PyTree into a result instance."""
del aux_data
return cls(*children)
def _tree_l2_norm(tree: Any) -> scalar_float:
"""Compute the Euclidean norm of all leaves in a pytree."""
leaves: list[Any] = jax.tree_util.tree_leaves(tree)
if not leaves:
return jnp.asarray(0.0)
squared_norm: scalar_float = jnp.asarray(0.0)
for leaf in leaves:
leaf_array: Array = jnp.asarray(leaf)
squared_norm = squared_norm + jnp.real(
jnp.vdot(leaf_array, leaf_array)
)
return jnp.sqrt(squared_norm)
def _result_from_history(
params: Any,
objective_history: list[scalar_float],
gradient_norm_history: list[scalar_float],
step_norm_history: list[scalar_float],
converged: bool,
) -> ReconstructionResult:
"""Assemble a result object from Python-side optimization traces."""
return ReconstructionResult(
params=params,
objective_history=jnp.asarray(objective_history),
gradient_norm_history=jnp.asarray(gradient_norm_history),
step_norm_history=jnp.asarray(step_norm_history),
iterations=jnp.asarray(len(objective_history), dtype=jnp.int32),
converged=jnp.asarray(converged),
)
def _apply_postprocess(
simulated_image: Float[Array, "H W"],
postprocess_fn: Optional[
Callable[[Float[Array, "H W"]], Float[Array, "H W"]]
],
) -> Float[Array, "H W"]:
"""Apply an optional post-processing transform to a simulated image."""
if postprocess_fn is None:
return simulated_image
return postprocess_fn(simulated_image)
[docs]
@jaxtyped(typechecker=beartype)
def gauss_newton_least_squares(
initial_params: Any,
residual_fn: Callable[[Any], Array],
damping: scalar_float = 1e-3,
step_scale: scalar_float = 1.0,
max_iterations: int = 25,
tolerance: scalar_float = 1e-6,
) -> ReconstructionResult:
r"""Minimize a least-squares objective with Gauss-Newton iterations.
Extended Summary
----------------
This solver targets objectives of the form
:math:`\min_\theta \|r(\theta)\|_2^2`, where ``residual_fn``
returns the residual vector or tensor. Parameters may be any JAX
pytree; internally they are flattened with
:func:`jax.flatten_util.ravel_pytree`.
Parameters
----------
initial_params : Any
Initial parameter pytree.
residual_fn : Callable[[Any], Array]
Residual function returning any array-shaped residual.
damping : scalar_float, optional
Levenberg-style diagonal damping added to the normal matrix.
Default: 1e-3
step_scale : scalar_float, optional
Scalar multiplier applied to the Gauss-Newton step.
Default: 1.0
max_iterations : int, optional
Maximum number of iterations. Default: 25
tolerance : scalar_float, optional
Convergence threshold applied to the gradient norm and update
norm. Default: 1e-6
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
flat_params: Float[Array, " P"]
unravel_fn: Callable[[Float[Array, "P"]], Any]
flat_params, unravel_fn = ravel_pytree(initial_params)
def flat_residual_fn(
flat_parameter_vector: Float[Array, "P"],
) -> Float[Array, " R"]:
residual: Array = residual_fn(unravel_fn(flat_parameter_vector))
return jnp.ravel(jnp.asarray(residual))
objective_history: list[scalar_float] = []
gradient_norm_history: list[scalar_float] = []
step_norm_history: list[scalar_float] = []
converged: bool = False
for _ in range(max_iterations):
residual_vector: Float[Array, " R"] = flat_residual_fn(flat_params)
objective_value: scalar_float = jnp.mean(residual_vector**2)
jacobian: Float[Array, "R P"] = jax.jacrev(flat_residual_fn)(
flat_params
)
gradient_vector: Float[Array, " P"] = jacobian.T @ residual_vector
gradient_norm: scalar_float = jnp.linalg.norm(gradient_vector)
if bool(gradient_norm <= tolerance):
objective_history.append(objective_value)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(jnp.asarray(0.0))
converged = True
break
normal_matrix: Float[Array, "P P"] = jacobian.T @ jacobian
identity: Float[Array, "P P"] = jnp.eye(
normal_matrix.shape[0], dtype=normal_matrix.dtype
)
step: Float[Array, " P"] = -step_scale * jnp.linalg.solve(
normal_matrix + damping * identity,
gradient_vector,
)
step_norm: scalar_float = jnp.linalg.norm(step)
flat_params = flat_params + step
updated_residual: Float[Array, " R"] = flat_residual_fn(flat_params)
updated_objective: scalar_float = jnp.mean(updated_residual**2)
objective_history.append(updated_objective)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(step_norm)
if bool(step_norm <= tolerance):
converged = True
break
return _result_from_history(
params=unravel_fn(flat_params),
objective_history=objective_history,
gradient_norm_history=gradient_norm_history,
step_norm_history=step_norm_history,
converged=converged,
)
[docs]
@jaxtyped(typechecker=beartype)
def adam_optimize(
initial_params: Any,
objective_fn: Callable[[Any], scalar_float],
learning_rate: scalar_float = 1e-2,
beta1: scalar_float = 0.9,
beta2: scalar_float = 0.999,
epsilon: scalar_float = 1e-8,
max_iterations: int = 250,
tolerance: scalar_float = 1e-6,
) -> ReconstructionResult:
r"""Minimize a scalar objective with the Adam optimizer.
Parameters
----------
initial_params : Any
Initial parameter pytree.
objective_fn : Callable[[Any], scalar_float]
Scalar objective to minimize.
learning_rate : scalar_float, optional
Adam learning rate. Default: 1e-2
beta1 : scalar_float, optional
Exponential decay factor for the first moment. Default: 0.9
beta2 : scalar_float, optional
Exponential decay factor for the second moment. Default: 0.999
epsilon : scalar_float, optional
Denominator stabilizer. Default: 1e-8
max_iterations : int, optional
Maximum number of iterations. Default: 250
tolerance : scalar_float, optional
Convergence threshold applied to the gradient norm and update
norm. Default: 1e-6
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
params: Any = initial_params
first_moment: Any = jax.tree_util.tree_map(jnp.zeros_like, params)
second_moment: Any = jax.tree_util.tree_map(jnp.zeros_like, params)
objective_history: list[scalar_float] = []
gradient_norm_history: list[scalar_float] = []
step_norm_history: list[scalar_float] = []
converged: bool = False
for iteration in range(1, max_iterations + 1):
objective_value: scalar_float
gradients: Any
objective_value, gradients = jax.value_and_grad(objective_fn)(params)
gradient_norm: scalar_float = _tree_l2_norm(gradients)
if bool(gradient_norm <= tolerance):
objective_history.append(objective_value)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(jnp.asarray(0.0))
converged = True
break
first_moment = jax.tree_util.tree_map(
lambda moment, grad: beta1 * moment + (1.0 - beta1) * grad,
first_moment,
gradients,
)
second_moment = jax.tree_util.tree_map(
lambda moment, grad: beta2 * moment + (1.0 - beta2) * grad**2,
second_moment,
gradients,
)
first_bias_correction: scalar_float = 1.0 - beta1**iteration
second_bias_correction: scalar_float = 1.0 - beta2**iteration
first_moment_hat: Any = jax.tree_util.tree_map(
lambda moment, correction=first_bias_correction: (
moment / correction
),
first_moment,
)
second_moment_hat: Any = jax.tree_util.tree_map(
lambda moment, correction=second_bias_correction: (
moment / correction
),
second_moment,
)
step: Any = jax.tree_util.tree_map(
lambda moment, variance: (
-learning_rate * moment / (jnp.sqrt(variance) + epsilon)
),
first_moment_hat,
second_moment_hat,
)
step_norm: scalar_float = _tree_l2_norm(step)
params = jax.tree_util.tree_map(
lambda param, update: param + update,
params,
step,
)
updated_objective: scalar_float = objective_fn(params)
objective_history.append(updated_objective)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(step_norm)
if bool(step_norm <= tolerance):
converged = True
break
return _result_from_history(
params=params,
objective_history=objective_history,
gradient_norm_history=gradient_norm_history,
step_norm_history=step_norm_history,
converged=converged,
)
[docs]
@jaxtyped(typechecker=beartype)
def adagrad_optimize(
initial_params: Any,
objective_fn: Callable[[Any], scalar_float],
learning_rate: scalar_float = 1e-1,
epsilon: scalar_float = 1e-8,
max_iterations: int = 500,
tolerance: scalar_float = 1e-6,
) -> ReconstructionResult:
r"""Minimize a scalar objective with the Adagrad optimizer.
Parameters
----------
initial_params : Any
Initial parameter pytree.
objective_fn : Callable[[Any], scalar_float]
Scalar objective to minimize.
learning_rate : scalar_float, optional
Adagrad base learning rate. Default: 1e-1
epsilon : scalar_float, optional
Denominator stabilizer. Default: 1e-8
max_iterations : int, optional
Maximum number of iterations. Default: 500
tolerance : scalar_float, optional
Convergence threshold applied to the gradient norm and update
norm. Default: 1e-6
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
params: Any = initial_params
accumulator: Any = jax.tree_util.tree_map(jnp.zeros_like, params)
objective_history: list[scalar_float] = []
gradient_norm_history: list[scalar_float] = []
step_norm_history: list[scalar_float] = []
converged: bool = False
for _ in range(max_iterations):
objective_value: scalar_float
gradients: Any
objective_value, gradients = jax.value_and_grad(objective_fn)(params)
gradient_norm: scalar_float = _tree_l2_norm(gradients)
if bool(gradient_norm <= tolerance):
objective_history.append(objective_value)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(jnp.asarray(0.0))
converged = True
break
accumulator = jax.tree_util.tree_map(
lambda state, grad: state + grad**2,
accumulator,
gradients,
)
step: Any = jax.tree_util.tree_map(
lambda grad, state: (
-learning_rate * grad / (jnp.sqrt(state) + epsilon)
),
gradients,
accumulator,
)
step_norm: scalar_float = _tree_l2_norm(step)
params = jax.tree_util.tree_map(
lambda param, update: param + update,
params,
step,
)
updated_objective: scalar_float = objective_fn(params)
objective_history.append(updated_objective)
gradient_norm_history.append(gradient_norm)
step_norm_history.append(step_norm)
if bool(step_norm <= tolerance):
converged = True
break
return _result_from_history(
params=params,
objective_history=objective_history,
gradient_norm_history=gradient_norm_history,
step_norm_history=step_norm_history,
converged=converged,
)
[docs]
@jaxtyped(typechecker=beartype)
def gauss_newton_reconstruction( # noqa: PLR0913
initial_params: Any,
forward_model: Callable[[Any], Float[Array, "H W"]],
experimental_image: Float[Array, "H W"],
weight_map: Optional[Float[Array, "H W"]] = None,
postprocess_fn: Optional[
Callable[[Float[Array, "H W"]], Float[Array, "H W"]]
] = None,
damping: scalar_float = 1e-3,
step_scale: scalar_float = 1.0,
max_iterations: int = 25,
tolerance: scalar_float = 1e-6,
) -> ReconstructionResult:
r"""Reconstruct parameters by least-squares image matching.
Parameters
----------
initial_params : Any
Initial parameter pytree passed to ``forward_model``.
forward_model : Callable[[Any], Float[Array, "H W"]]
Differentiable simulator that maps parameters to a detector
image.
experimental_image : Float[Array, "H W"]
Target detector image.
weight_map : Float[Array, "H W"], optional
Non-negative per-pixel weights for least-squares fitting.
postprocess_fn : Callable[[Float[Array, "H W"]], \
Float[Array, "H W"]], optional
Optional transform applied to each simulated image before it is
compared against ``experimental_image``.
damping : scalar_float, optional
Diagonal damping for the Gauss-Newton normal equations.
Default: 1e-3
step_scale : scalar_float, optional
Scalar multiplier applied to each update. Default: 1.0
max_iterations : int, optional
Maximum number of iterations. Default: 25
tolerance : scalar_float, optional
Convergence threshold. Default: 1e-6
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
def residual_fn(params: Any) -> Float[Array, "H W"]:
simulated_image: Float[Array, "H W"] = _apply_postprocess(
forward_model(params),
postprocess_fn,
)
return weighted_image_residual(
simulated_image=simulated_image,
experimental_image=experimental_image,
weight_map=weight_map,
)
return gauss_newton_least_squares(
initial_params=initial_params,
residual_fn=residual_fn,
damping=damping,
step_scale=step_scale,
max_iterations=max_iterations,
tolerance=tolerance,
)
[docs]
@jaxtyped(typechecker=beartype)
def adam_reconstruction( # noqa: PLR0913
initial_params: Any,
forward_model: Callable[[Any], Float[Array, "H W"]],
experimental_image: Float[Array, "H W"],
weight_map: Optional[Float[Array, "H W"]] = None,
postprocess_fn: Optional[
Callable[[Float[Array, "H W"]], Float[Array, "H W"]]
] = None,
learning_rate: scalar_float = 1e-2,
beta1: scalar_float = 0.9,
beta2: scalar_float = 0.999,
epsilon: scalar_float = 1e-8,
max_iterations: int = 250,
tolerance: scalar_float = 1e-6,
loss_fn: Callable[
[
Float[Array, "H W"],
Float[Array, "H W"],
Optional[Float[Array, "H W"]],
],
scalar_float,
] = weighted_mean_squared_error,
) -> ReconstructionResult:
r"""Reconstruct parameters by minimizing an image-matching loss.
Parameters
----------
initial_params : Any
Initial parameter pytree passed to ``forward_model``.
forward_model : Callable[[Any], Float[Array, "H W"]]
Differentiable simulator that maps parameters to a detector
image.
experimental_image : Float[Array, "H W"]
Target detector image.
weight_map : Float[Array, "H W"], optional
Optional non-negative per-pixel weights.
postprocess_fn : Callable[[Float[Array, "H W"]], \
Float[Array, "H W"]], optional
Optional transform applied to each simulated image before loss
evaluation.
learning_rate : scalar_float, optional
Adam learning rate. Default: 1e-2
beta1 : scalar_float, optional
First-moment decay factor. Default: 0.9
beta2 : scalar_float, optional
Second-moment decay factor. Default: 0.999
epsilon : scalar_float, optional
Denominator stabilizer. Default: 1e-8
max_iterations : int, optional
Maximum number of iterations. Default: 250
tolerance : scalar_float, optional
Convergence threshold. Default: 1e-6
loss_fn : Callable[..., scalar_float], optional
Image loss used for optimization. Default:
:func:`weighted_mean_squared_error`
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
def objective_fn(params: Any) -> scalar_float:
simulated_image: Float[Array, "H W"] = _apply_postprocess(
forward_model(params),
postprocess_fn,
)
return loss_fn(simulated_image, experimental_image, weight_map)
return adam_optimize(
initial_params=initial_params,
objective_fn=objective_fn,
learning_rate=learning_rate,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
max_iterations=max_iterations,
tolerance=tolerance,
)
[docs]
@jaxtyped(typechecker=beartype)
def adagrad_reconstruction( # noqa: PLR0913
initial_params: Any,
forward_model: Callable[[Any], Float[Array, "H W"]],
experimental_image: Float[Array, "H W"],
weight_map: Optional[Float[Array, "H W"]] = None,
postprocess_fn: Optional[
Callable[[Float[Array, "H W"]], Float[Array, "H W"]]
] = None,
learning_rate: scalar_float = 1e-1,
epsilon: scalar_float = 1e-8,
max_iterations: int = 500,
tolerance: scalar_float = 1e-6,
loss_fn: Callable[
[
Float[Array, "H W"],
Float[Array, "H W"],
Optional[Float[Array, "H W"]],
],
scalar_float,
] = weighted_mean_squared_error,
) -> ReconstructionResult:
r"""Reconstruct parameters by minimizing an image-matching loss.
Parameters
----------
initial_params : Any
Initial parameter pytree passed to ``forward_model``.
forward_model : Callable[[Any], Float[Array, "H W"]]
Differentiable simulator that maps parameters to a detector
image.
experimental_image : Float[Array, "H W"]
Target detector image.
weight_map : Float[Array, "H W"], optional
Optional non-negative per-pixel weights.
postprocess_fn : Callable[[Float[Array, "H W"]], \
Float[Array, "H W"]], optional
Optional transform applied to each simulated image before loss
evaluation.
learning_rate : scalar_float, optional
Adagrad base learning rate. Default: 1e-1
epsilon : scalar_float, optional
Denominator stabilizer. Default: 1e-8
max_iterations : int, optional
Maximum number of iterations. Default: 500
tolerance : scalar_float, optional
Convergence threshold. Default: 1e-6
loss_fn : Callable[..., scalar_float], optional
Image loss used for optimization. Default:
:func:`weighted_mean_squared_error`
Returns
-------
result : ReconstructionResult
Final parameters plus optimization traces.
"""
def objective_fn(params: Any) -> scalar_float:
simulated_image: Float[Array, "H W"] = _apply_postprocess(
forward_model(params),
postprocess_fn,
)
return loss_fn(simulated_image, experimental_image, weight_map)
return adagrad_optimize(
initial_params=initial_params,
objective_fn=objective_fn,
learning_rate=learning_rate,
epsilon=epsilon,
max_iterations=max_iterations,
tolerance=tolerance,
)
__all__: list[str] = [
"ReconstructionResult",
"adagrad_optimize",
"adagrad_reconstruction",
"adam_optimize",
"adam_reconstruction",
"gauss_newton_least_squares",
"gauss_newton_reconstruction",
]