r"""Inverse recovery of orientation distributions from experimental RHEED.
Extended Summary
----------------
This module solves the inverse problem: given experimental RHEED data,
recover a probability distribution over domain orientations. The current
public API fits the discrete variant weights, together with a shared
Gaussian mosaic broadening, on a fixed candidate angle support.
The forward model is:
.. math::
I_\mathrm{sim} = \int P(\theta)\, I(\theta)\, d\theta
\approx \sum_i w_i\, I(\theta_i)
The inverse problem minimizes a masked image-matching loss with optional
weight regularization and entropy stabilization.
Routine Listings
----------------
:class:`OrientationFitResult`
Container for orientation-fitting outputs and diagnostics.
:func:`orientation_loss`
Compute loss between an observed pattern and a trial distribution.
:func:`fit_orientation_weights`
Optimize discrete orientation weights on a fixed support set.
:func:`compute_fisher_information`
Estimate Fisher information for the fitted weight logits.
:func:`estimate_weight_uncertainty`
Propagate Fisher information to 1σ weight uncertainties.
Notes
-----
The inverse problem is most informative when:
1. Different orientations produce distinguishable patterns.
2. The number of fitted orientation parameters is smaller than the
number of independent detector pixels.
3. Regularization prevents overfitting to noise.
"""
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Callable, Final, NamedTuple, Optional, Tuple
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array, Bool, Float, Int, jaxtyped
from rheedium.types import (
OrientationDistribution,
create_orientation_distribution,
float_jax_image,
integrate_over_orientation,
scalar_float,
scalar_int,
)
_PROBABILITY_EPS: Final[float] = 1e-10
_DEFAULT_MOSAIC_PARAM: Final[float] = -5.0
_FISHER_REGULARIZATION: Final[float] = 1e-6
_ADAM_BETA1: Final[float] = 0.9
_ADAM_BETA2: Final[float] = 0.999
_ADAM_EPSILON: Final[float] = 1e-8
class _OrientationWeightParameters(NamedTuple):
"""Internal unconstrained parameterization for weight fitting."""
weight_logits: Float[Array, "M"]
mosaic_param: Float[Array, ""]
def tree_flatten(
self,
) -> Tuple[
Tuple[Float[Array, "M"], Float[Array, ""]],
None,
]:
"""Flatten for JAX PyTree support."""
return ((self.weight_logits, self.mosaic_param), None)
@classmethod
def tree_unflatten(
cls,
aux_data: None,
children: Tuple[Float[Array, "M"], Float[Array, ""]],
) -> "_OrientationWeightParameters":
"""Unflatten from a JAX PyTree."""
del aux_data
return cls(*children)
class _OrientationAdamState(NamedTuple):
"""Internal Adam optimizer state for orientation fitting."""
first_logits: Float[Array, "M"]
second_logits: Float[Array, "M"]
first_mosaic: Float[Array, ""]
second_mosaic: Float[Array, ""]
def tree_flatten(
self,
) -> Tuple[
Tuple[
Float[Array, "M"],
Float[Array, "M"],
Float[Array, ""],
Float[Array, ""],
],
None,
]:
"""Flatten for JAX PyTree support."""
return (
(
self.first_logits,
self.second_logits,
self.first_mosaic,
self.second_mosaic,
),
None,
)
@classmethod
def tree_unflatten(
cls,
aux_data: None,
children: Tuple[
Float[Array, "M"],
Float[Array, "M"],
Float[Array, ""],
Float[Array, ""],
],
) -> "_OrientationAdamState":
"""Unflatten from a JAX PyTree."""
del aux_data
return cls(*children)
_OrientationObjectiveGradFn = Callable[
[_OrientationWeightParameters],
Tuple[Float[Array, ""], _OrientationWeightParameters],
]
_OrientationOptimizerCarry = Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Bool[Array, ""],
scalar_int,
Float[Array, ""],
]
_OrientationOptimizerStep = Tuple[
_OrientationOptimizerCarry,
Float[Array, ""],
]
_OrientationOptimizerResult = Tuple[
_OrientationWeightParameters,
Float[Array, "N_steps_full"],
scalar_int,
Float[Array, ""],
Bool[Array, ""],
]
[docs]
@register_pytree_node_class
class OrientationFitResult(NamedTuple):
"""Results from orientation distribution fitting.
Attributes
----------
fitted_distribution : OrientationDistribution
Recovered orientation distribution.
final_loss : Float[Array, ""]
Final scalar loss value.
loss_history : Float[Array, "N_steps"]
Loss value after each optimizer step.
converged : bool
Whether the optimizer met its stopping tolerance.
n_iterations : int
Number of recorded optimization iterations.
residual_pattern : float_jax_image
Difference ``I_observed - I_fitted`` for diagnostics.
"""
fitted_distribution: OrientationDistribution
final_loss: Float[Array, ""]
loss_history: Float[Array, "N_steps"]
converged: bool
n_iterations: int
residual_pattern: float_jax_image
def tree_flatten(
self,
) -> Tuple[
Tuple[
OrientationDistribution,
Float[Array, ""],
Float[Array, "N_steps"],
float_jax_image,
],
Tuple[bool, int],
]:
"""Flatten for JAX PyTree support."""
return (
(
self.fitted_distribution,
self.final_loss,
self.loss_history,
self.residual_pattern,
),
(self.converged, self.n_iterations),
)
@classmethod
def tree_unflatten(
cls,
aux_data: Tuple[bool, int],
children: Tuple[
OrientationDistribution,
Float[Array, ""],
Float[Array, "N_steps"],
float_jax_image,
],
) -> "OrientationFitResult":
"""Unflatten from a JAX PyTree."""
return cls(
fitted_distribution=children[0],
final_loss=children[1],
loss_history=children[2],
converged=aux_data[0],
n_iterations=aux_data[1],
residual_pattern=children[3],
)
@jaxtyped(typechecker=beartype)
def _normalize_pattern(
pattern: float_jax_image,
mask: Optional[float_jax_image] = None,
) -> float_jax_image:
"""Normalize a detector image to unit sum within an optional mask."""
mask_array: float_jax_image
if mask is None:
mask_array = jnp.ones_like(pattern)
else:
mask_array = jnp.asarray(mask, dtype=jnp.float64)
masked_pattern: float_jax_image = pattern * mask_array
total_intensity: Float[Array, ""] = (
jnp.sum(masked_pattern) + _PROBABILITY_EPS
)
return masked_pattern / total_intensity
@jaxtyped(typechecker=beartype)
def _normalize_probability_weights(
weights: Float[Array, "M"],
) -> Float[Array, "M"]:
"""Clip and renormalize weights to a valid probability simplex."""
clipped_weights: Float[Array, "M"] = jnp.clip(
jnp.asarray(weights, dtype=jnp.float64),
0.0,
None,
)
weight_sum: Float[Array, ""] = jnp.sum(clipped_weights)
n_weights: int = clipped_weights.shape[0]
uniform_weights: Float[Array, "M"] = (
jnp.ones(n_weights, dtype=jnp.float64) / n_weights
)
return jax.lax.cond(
weight_sum > 0.0,
lambda: clipped_weights / weight_sum,
lambda: uniform_weights,
)
@jaxtyped(typechecker=beartype)
def _softmax_weights(weight_logits: Float[Array, "M"]) -> Float[Array, "M"]:
"""Convert unconstrained logits to normalized positive weights."""
shifted_logits: Float[Array, "M"] = weight_logits - jnp.max(weight_logits)
exp_logits: Float[Array, "M"] = jnp.exp(shifted_logits)
softmax: Float[Array, "M"] = exp_logits / (
jnp.sum(exp_logits) + _PROBABILITY_EPS
)
return softmax
@jaxtyped(typechecker=beartype)
def _softplus(x: Float[Array, ""]) -> Float[Array, ""]:
"""Softplus transform for non-negative scalar parameters."""
return jnp.log1p(jnp.exp(x))
@jaxtyped(typechecker=beartype)
def _sanitize_distribution(
distribution: OrientationDistribution,
) -> OrientationDistribution:
"""Return a numerically safe orientation distribution."""
return create_orientation_distribution(
angles_deg=distribution.discrete_angles_deg,
weights=distribution.discrete_weights,
mosaic_fwhm_deg=distribution.mosaic_fwhm_deg,
distribution_id=distribution.distribution_id,
)
@jaxtyped(typechecker=beartype)
def _distribution_from_parameters(
candidate_angles_deg: Float[Array, "M"],
params: _OrientationWeightParameters,
) -> OrientationDistribution:
"""Build a physical distribution from unconstrained optimizer params."""
return create_orientation_distribution(
angles_deg=candidate_angles_deg,
weights=_softmax_weights(params.weight_logits),
mosaic_fwhm_deg=_softplus(params.mosaic_param),
distribution_id=None,
)
@jaxtyped(typechecker=beartype)
def _simulate_distribution_pattern(
simulate_fn: Callable[[scalar_float], float_jax_image],
distribution: OrientationDistribution,
n_mosaic_points: scalar_int,
) -> float_jax_image:
"""Evaluate the forward model for one orientation distribution."""
return integrate_over_orientation(
simulate_fn,
_sanitize_distribution(distribution),
n_mosaic_points=n_mosaic_points,
)
@jaxtyped(typechecker=beartype)
def _parameter_norm(params: _OrientationWeightParameters) -> Float[Array, ""]:
"""Compute the Euclidean norm of the fit parameter pytree."""
logit_norm_sq: Float[Array, ""] = jnp.real(
jnp.vdot(params.weight_logits, params.weight_logits)
)
mosaic_norm_sq: Float[Array, ""] = jnp.square(params.mosaic_param)
return jnp.sqrt(logit_norm_sq + mosaic_norm_sq)
@jaxtyped(typechecker=beartype)
def _adam_update(
params: _OrientationWeightParameters,
gradients: _OrientationWeightParameters,
optimizer_state: _OrientationAdamState,
learning_rate: scalar_float,
iteration: scalar_int,
) -> Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
]:
"""Apply one Adam step to the orientation fit parameters."""
first_logits: Float[Array, "M"] = (
_ADAM_BETA1 * optimizer_state.first_logits
+ (1.0 - _ADAM_BETA1) * gradients.weight_logits
)
second_logits: Float[Array, "M"] = (
_ADAM_BETA2 * optimizer_state.second_logits
+ (1.0 - _ADAM_BETA2) * gradients.weight_logits**2
)
first_mosaic: Float[Array, ""] = (
_ADAM_BETA1 * optimizer_state.first_mosaic
+ (1.0 - _ADAM_BETA1) * gradients.mosaic_param
)
second_mosaic: Float[Array, ""] = (
_ADAM_BETA2 * optimizer_state.second_mosaic
+ (1.0 - _ADAM_BETA2) * gradients.mosaic_param**2
)
first_bias_correction: Float[Array, ""] = 1.0 - _ADAM_BETA1**iteration
second_bias_correction: Float[Array, ""] = 1.0 - _ADAM_BETA2**iteration
first_logits_hat: Float[Array, "M"] = first_logits / first_bias_correction
second_logits_hat: Float[Array, "M"] = (
second_logits / second_bias_correction
)
first_mosaic_hat: Float[Array, ""] = first_mosaic / first_bias_correction
second_mosaic_hat: Float[Array, ""] = (
second_mosaic / second_bias_correction
)
logit_step: Float[Array, "M"] = (
-learning_rate
* first_logits_hat
/ (jnp.sqrt(second_logits_hat) + _ADAM_EPSILON)
)
mosaic_step: Float[Array, ""] = (
-learning_rate
* first_mosaic_hat
/ (jnp.sqrt(second_mosaic_hat) + _ADAM_EPSILON)
)
updated_params: _OrientationWeightParameters = (
_OrientationWeightParameters(
weight_logits=params.weight_logits + logit_step,
mosaic_param=params.mosaic_param + mosaic_step,
)
)
updated_state: _OrientationAdamState = _OrientationAdamState(
first_logits=first_logits,
second_logits=second_logits,
first_mosaic=first_mosaic,
second_mosaic=second_mosaic,
)
step_norm: Float[Array, ""] = _parameter_norm(
_OrientationWeightParameters(
weight_logits=logit_step,
mosaic_param=mosaic_step,
)
)
return updated_params, updated_state, step_norm
@jaxtyped(typechecker=beartype)
def _prepare_pattern_for_loss(
pattern: float_jax_image,
mask: Optional[float_jax_image] = None,
normalize: bool = True,
) -> float_jax_image:
"""Apply masking and optional normalization to a detector pattern."""
if normalize:
return _normalize_pattern(pattern, mask)
if mask is None:
return jnp.asarray(pattern, dtype=jnp.float64)
return jnp.asarray(pattern, dtype=jnp.float64) * jnp.asarray(
mask,
dtype=jnp.float64,
)
[docs]
@jaxtyped(typechecker=beartype)
def orientation_loss(
distribution: OrientationDistribution,
simulate_fn: Callable[[scalar_float], float_jax_image],
observed_pattern: float_jax_image,
mask: Optional[float_jax_image] = None,
normalize: bool = True,
regularization_strength: scalar_float = 0.0,
entropy_weight: scalar_float = 0.0,
reference_weights: Optional[Float[Array, "M"]] = None,
n_mosaic_points: scalar_int = 7,
) -> Float[Array, ""]:
"""Compute loss between observed and simulated patterns.
Description
-----------
Computes a masked mean-squared error between an observed detector
image and the incoherently averaged forward model implied by an
:class:`OrientationDistribution`.
Parameters
----------
distribution : OrientationDistribution
Trial orientation distribution.
simulate_fn : Callable[[scalar_float], Float[Array, "H W"]]
Forward simulation function mapping ``phi_deg`` to a detector
pattern.
observed_pattern : Float[Array, "H W"]
Experimental or synthetic detector image to match.
mask : Optional[Float[Array, "H W"]], optional
Pixel mask for loss computation. Default: all pixels.
normalize : bool, optional
If True, normalize both patterns before comparison. Default: True
regularization_strength : scalar_float, optional
L2 penalty on deviations from ``reference_weights``. If
``reference_weights`` is omitted, a uniform distribution is used.
Default: 0.0
entropy_weight : scalar_float, optional
Weight on the entropy bonus ``-H(w)``. Positive values discourage
collapse onto a single orientation. Default: 0.0
reference_weights : Optional[Float[Array, "M"]], optional
Reference weight vector for the L2 penalty. Default: uniform
weights on the candidate support.
n_mosaic_points : scalar_int, optional
Quadrature points for mosaic broadening. Default: 7
Returns
-------
loss : Float[Array, ""]
Scalar objective value.
"""
sanitized_distribution: OrientationDistribution = _sanitize_distribution(
distribution
)
simulated_pattern: float_jax_image = _simulate_distribution_pattern(
simulate_fn,
sanitized_distribution,
n_mosaic_points=n_mosaic_points,
)
prepared_observed: float_jax_image = _prepare_pattern_for_loss(
observed_pattern,
mask=mask,
normalize=normalize,
)
prepared_simulated: float_jax_image = _prepare_pattern_for_loss(
simulated_pattern,
mask=mask,
normalize=normalize,
)
mask_array: float_jax_image
if mask is None:
mask_array = jnp.ones_like(observed_pattern)
else:
mask_array = jnp.asarray(mask, dtype=jnp.float64)
n_pixels: Float[Array, ""] = jnp.sum(mask_array) + _PROBABILITY_EPS
mse: Float[Array, ""] = (
jnp.sum(jnp.square(prepared_observed - prepared_simulated)) / n_pixels
)
weights: Float[Array, "M"] = sanitized_distribution.discrete_weights
target_weights: Float[Array, "M"]
if reference_weights is None:
n_orientations: int = weights.shape[0]
target_weights = jnp.ones(n_orientations, dtype=jnp.float64) / (
n_orientations
)
else:
target_weights = _normalize_probability_weights(reference_weights)
l2_penalty: Float[Array, ""] = jnp.sum(
jnp.square(weights - target_weights)
)
entropy: Float[Array, ""] = -jnp.sum(
weights * jnp.log(weights + _PROBABILITY_EPS)
)
return (
mse + regularization_strength * l2_penalty - entropy_weight * entropy
)
[docs]
@jaxtyped(typechecker=beartype)
def fit_orientation_weights( # noqa: PLR0913, PLR0915
observed_pattern: float_jax_image,
simulate_fn: Callable[[scalar_float], float_jax_image],
candidate_angles_deg: Float[Array, "M"],
mask: Optional[float_jax_image] = None,
initial_weights: Optional[Float[Array, "M"]] = None,
learning_rate: scalar_float = 0.1,
n_iterations: scalar_int = 500,
convergence_tol: scalar_float = 1e-6,
regularization_strength: scalar_float = 1e-4,
entropy_weight: scalar_float = 1e-3,
n_mosaic_points: scalar_int = 7,
normalize: bool = True,
verbose: bool = False,
) -> OrientationFitResult:
"""Optimize orientation weights to match an observed pattern.
Description
-----------
Given a fixed candidate angle support, optimize the discrete
orientation weights together with a shared Gaussian mosaic width.
The internal optimization is carried out in an unconstrained space
with a specialized Adam update on ``(weight_logits, mosaic_param)``
and a JIT-compiled :func:`jax.lax.scan` loop.
Parameters
----------
observed_pattern : Float[Array, "H W"]
Experimental RHEED pattern to match.
simulate_fn : Callable[[scalar_float], Float[Array, "H W"]]
Forward simulation function mapping ``phi_deg`` to a detector
pattern.
candidate_angles_deg : Float[Array, "M"]
Candidate orientation angles in degrees.
mask : Optional[Float[Array, "H W"]], optional
Pixel mask for loss computation. Default: all pixels.
initial_weights : Optional[Float[Array, "M"]], optional
Initial weight guess. Default: uniform weights.
learning_rate : scalar_float, optional
Adam learning rate. Default: 0.1
n_iterations : scalar_int, optional
Maximum optimization steps. Default: 500
convergence_tol : scalar_float, optional
Stop when the optimizer update norm falls below this tolerance.
Default: 1e-6
regularization_strength : scalar_float, optional
L2 regularization on weight deviations from uniform.
Default: 1e-4
entropy_weight : scalar_float, optional
Entropy regularization. Default: 1e-3
n_mosaic_points : scalar_int, optional
Quadrature points for mosaic broadening. Default: 7
normalize : bool, optional
If True, normalize both observed and simulated patterns before
comparison. Default: True
verbose : bool, optional
Print the final fitted weights and loss. Default: False
Returns
-------
result : OrientationFitResult
Fitting result with recovered distribution and diagnostics.
"""
candidate_angles: Float[Array, "M"] = jnp.atleast_1d(
jnp.asarray(candidate_angles_deg, dtype=jnp.float64)
)
n_angles: int = candidate_angles.shape[0]
initial_weight_array: Float[Array, "M"]
if initial_weights is None:
initial_weight_array = jnp.ones(n_angles, dtype=jnp.float64) / n_angles
else:
initial_weight_array = _normalize_probability_weights(initial_weights)
initial_params: _OrientationWeightParameters = (
_OrientationWeightParameters(
weight_logits=jnp.log(initial_weight_array + _PROBABILITY_EPS),
mosaic_param=jnp.asarray(
_DEFAULT_MOSAIC_PARAM,
dtype=jnp.float64,
),
)
)
def objective_fn(
params: _OrientationWeightParameters,
) -> Float[Array, ""]:
return orientation_loss(
distribution=_distribution_from_parameters(
candidate_angles,
params,
),
simulate_fn=simulate_fn,
observed_pattern=observed_pattern,
mask=mask,
normalize=normalize,
regularization_strength=regularization_strength,
entropy_weight=entropy_weight,
n_mosaic_points=n_mosaic_points,
)
n_iterations_int: int = int(n_iterations)
def _run_optimizer( # noqa: PLR0915
start_params: _OrientationWeightParameters,
) -> _OrientationOptimizerResult:
objective_and_grad_fn: _OrientationObjectiveGradFn = (
jax.value_and_grad(objective_fn)
)
initial_optimizer_state: _OrientationAdamState = _OrientationAdamState(
first_logits=jnp.zeros_like(start_params.weight_logits),
second_logits=jnp.zeros_like(start_params.weight_logits),
first_mosaic=jnp.zeros_like(start_params.mosaic_param),
second_mosaic=jnp.zeros_like(start_params.mosaic_param),
)
initial_loss: Float[Array, ""] = objective_fn(start_params)
def _step(
carry: _OrientationOptimizerCarry,
step_index: Int[Array, ""],
) -> _OrientationOptimizerStep:
params: _OrientationWeightParameters
optimizer_state: _OrientationAdamState
converged_flag: Bool[Array, ""]
steps_taken: scalar_int
last_loss: Float[Array, ""]
(
params,
optimizer_state,
converged_flag,
steps_taken,
last_loss,
) = carry
def _frozen_step(
frozen_carry: _OrientationOptimizerCarry,
) -> _OrientationOptimizerStep:
return frozen_carry, frozen_carry[-1]
def _active_step(
active_carry: _OrientationOptimizerCarry,
) -> _OrientationOptimizerStep:
active_params: _OrientationWeightParameters
active_optimizer_state: _OrientationAdamState
_active_converged_flag: Bool[Array, ""]
active_steps: scalar_int
_active_last_loss: Float[Array, ""]
(
active_params,
active_optimizer_state,
_active_converged_flag,
active_steps,
_active_last_loss,
) = active_carry
objective_value: Float[Array, ""]
gradients: _OrientationWeightParameters
objective_value, gradients = objective_and_grad_fn(
active_params
)
gradient_norm: Float[Array, ""] = _parameter_norm(gradients)
def _converged_on_gradient(
operand: Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
],
) -> Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
Bool[Array, ""],
]:
grad_params: _OrientationWeightParameters
grad_optimizer_state: _OrientationAdamState
grad_loss: Float[Array, ""]
grad_params, grad_optimizer_state, grad_loss = operand
return (
grad_params,
grad_optimizer_state,
grad_loss,
jnp.asarray(True),
)
def _take_update_step(
operand: Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
],
) -> Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
Bool[Array, ""],
]:
update_params: _OrientationWeightParameters
update_optimizer_state: _OrientationAdamState
_update_loss: Float[Array, ""]
update_params, update_optimizer_state, _update_loss = (
operand
)
next_params: _OrientationWeightParameters
next_optimizer_state: _OrientationAdamState
step_norm: Float[Array, ""]
next_params, next_optimizer_state, step_norm = (
_adam_update(
params=update_params,
gradients=gradients,
optimizer_state=update_optimizer_state,
learning_rate=learning_rate,
iteration=step_index + 1,
)
)
next_loss: Float[Array, ""] = objective_fn(next_params)
next_converged: Bool[Array, ""] = (
step_norm <= convergence_tol
)
return (
next_params,
next_optimizer_state,
next_loss,
next_converged,
)
next_params: _OrientationWeightParameters
next_optimizer_state: _OrientationAdamState
recorded_loss: Float[Array, ""]
next_converged_flag: Bool[Array, ""]
cond_result: Tuple[
_OrientationWeightParameters,
_OrientationAdamState,
Float[Array, ""],
Bool[Array, ""],
]
cond_result = jax.lax.cond(
gradient_norm <= convergence_tol,
_converged_on_gradient,
_take_update_step,
(active_params, active_optimizer_state, objective_value),
)
(
next_params,
next_optimizer_state,
recorded_loss,
next_converged_flag,
) = cond_result
next_steps: scalar_int = active_steps + jnp.asarray(
1, dtype=jnp.int32
)
next_carry: _OrientationOptimizerCarry = (
next_params,
next_optimizer_state,
next_converged_flag,
next_steps,
recorded_loss,
)
return next_carry, recorded_loss
return jax.lax.cond(
converged_flag,
_frozen_step,
_active_step,
carry,
)
final_carry: _OrientationOptimizerCarry
loss_history_full: Float[Array, "N_steps_full"]
final_carry, loss_history_full = jax.lax.scan(
_step,
(
start_params,
initial_optimizer_state,
jnp.asarray(False),
jnp.asarray(0, dtype=jnp.int32),
initial_loss,
),
jnp.arange(n_iterations_int, dtype=jnp.int32),
)
final_params: _OrientationWeightParameters
_final_optimizer_state: _OrientationAdamState
converged_flag: Bool[Array, ""]
steps_taken: scalar_int
final_loss: Float[Array, ""]
(
final_params,
_final_optimizer_state,
converged_flag,
steps_taken,
final_loss,
) = final_carry
return (
final_params,
loss_history_full,
steps_taken,
final_loss,
converged_flag,
)
fitted_params: _OrientationWeightParameters
full_loss_history: Float[Array, "N_steps_full"]
completed_steps: scalar_int
final_loss: Float[Array, ""]
converged_flag: Bool[Array, ""]
(
fitted_params,
full_loss_history,
completed_steps,
final_loss,
converged_flag,
) = jax.jit(_run_optimizer)(initial_params)
completed_steps_int: int = int(completed_steps)
loss_history: Float[Array, "N_steps"] = full_loss_history[
:completed_steps_int
]
fitted_distribution: OrientationDistribution = (
_distribution_from_parameters(
candidate_angles,
fitted_params,
)
)
simulated_pattern: float_jax_image = _simulate_distribution_pattern(
simulate_fn,
fitted_distribution,
n_mosaic_points=n_mosaic_points,
)
residual_pattern: float_jax_image = (
jnp.asarray(observed_pattern, dtype=jnp.float64) - simulated_pattern
)
if completed_steps_int == 0:
final_loss = objective_fn(initial_params)
if verbose:
fitted_weights: list[str] = [
f"{float(weight):.3f}"
for weight in fitted_distribution.discrete_weights
]
print(
"Orientation fit:",
f"loss={float(final_loss):.6f}",
f"weights={fitted_weights}",
f"mosaic_fwhm_deg={float(fitted_distribution.mosaic_fwhm_deg):.4f}",
)
return OrientationFitResult(
fitted_distribution=fitted_distribution,
final_loss=final_loss,
loss_history=loss_history,
converged=bool(converged_flag),
n_iterations=completed_steps_int,
residual_pattern=residual_pattern,
)
[docs]
@jaxtyped(typechecker=beartype)
def estimate_weight_uncertainty(
result: OrientationFitResult,
simulate_fn: Callable[[scalar_float], float_jax_image],
noise_variance: scalar_float = 1.0,
mask: Optional[float_jax_image] = None,
normalize: bool = True,
n_mosaic_points: scalar_int = 7,
) -> Float[Array, "M"]:
"""Estimate 1σ uncertainties on the fitted discrete weights.
Description
-----------
Fisher information is computed in the unconstrained softmax-logit
parameterization and then propagated to the physical weights with the
Jacobian of the softmax map.
Parameters
----------
result : OrientationFitResult
Fitting result from :func:`fit_orientation_weights`.
simulate_fn : Callable[[scalar_float], Float[Array, "H W"]]
Forward simulation function.
noise_variance : scalar_float, optional
Assumed per-pixel Gaussian noise variance. Default: 1.0
mask : Optional[Float[Array, "H W"]], optional
Optional detector mask. Default: all pixels.
normalize : bool, optional
If True, compute uncertainties on normalized detector patterns.
Default: True
n_mosaic_points : scalar_int, optional
Quadrature points for mosaic broadening. Default: 7
Returns
-------
uncertainties : Float[Array, "M"]
Approximate 1σ uncertainty for each discrete orientation weight.
Notes
-----
The returned uncertainties are conditional on the fitted orientation
angles and mosaic broadening being held fixed.
"""
fitted_distribution: OrientationDistribution = _sanitize_distribution(
result.fitted_distribution
)
fisher_logits: Float[Array, "M M"] = compute_fisher_information(
simulate_fn,
fitted_distribution,
noise_variance=noise_variance,
mask=mask,
normalize=normalize,
n_mosaic_points=n_mosaic_points,
)
n_weights: int = fisher_logits.shape[0]
fisher_regularized: Float[Array, "M M"] = fisher_logits + (
_FISHER_REGULARIZATION * jnp.eye(n_weights, dtype=fisher_logits.dtype)
)
covariance_logits: Float[Array, "M M"] = jnp.linalg.inv(fisher_regularized)
weights: Float[Array, "M"] = fitted_distribution.discrete_weights
softmax_jacobian: Float[Array, "M M"] = (
jnp.diag(weights) - weights[:, None] * weights[None, :]
)
covariance_weights: Float[Array, "M M"] = (
softmax_jacobian @ covariance_logits @ softmax_jacobian.T
)
variances: Float[Array, "M"] = jnp.maximum(
jnp.diag(covariance_weights),
0.0,
)
return jnp.sqrt(variances)
__all__: list[str] = [
"OrientationFitResult",
"compute_fisher_information",
"estimate_weight_uncertainty",
"fit_orientation_weights",
"orientation_loss",
]