Source code for optixstuff.optical_elements
"""Optical element abstractions (throughput, filters, field stops)."""
import abc
from typing import final
import equinox as eqx
import interpax
import jax.numpy as jnp
from jax.typing import ArrayLike
from jaxtyping import Array
class AbstractOpticalElement(eqx.Module):
    """Interface shared by every optical element placed in the beam path.

    An element attenuates photon flux according to a wavelength-dependent
    throughput. Two entry points are exposed:

    * ``get_throughput`` -- scalar efficiency, consumed by the ETC.
    * ``apply`` -- attenuation of a 2D photon-rate array, consumed by the
      simulator.

    Neither method has an implementation at this level. Elements whose
    throughput is spatially uniform should derive from
    ``AbstractUniformElement``, which supplies a default ``apply``.
    """

    @abc.abstractmethod
    def get_throughput(self, wavelength_nm: ArrayLike) -> ArrayLike:
        """Return the fractional throughput at ``wavelength_nm``.

        Args:
            wavelength_nm: Wavelength in nanometres.

        Returns:
            Scalar fractional throughput in [0, 1].
        """

    @abc.abstractmethod
    def apply(self, arr: Array, wavelength_nm: ArrayLike) -> Array:
        """Attenuate a 2D photon array by this element.

        Args:
            arr: Photon rate array [ph/s/pixel].
            wavelength_nm: Wavelength in nanometres.

        Returns:
            The attenuated photon rate array, with the same shape as ``arr``.
        """


@final
class ConstantThroughput(AbstractUniformElement):
    """Optical element whose throughput is the same at every wavelength.

    Handy for simple attenuators and beamsplitters, or as a stand-in while
    an instrument design is still being explored.

    Attributes:
        throughput: Fractional throughput reported for any wavelength.
        name: Label shown in ``repr``. Declared ``static``, so Equinox keeps
            it out of the pytree leaves.
    """

    throughput: float
    name: str = eqx.field(default="element", static=True)

    def get_throughput(self, wavelength_nm: ArrayLike) -> ArrayLike:
        """Return ``self.throughput``; ``wavelength_nm`` is not consulted."""
        del wavelength_nm  # wavelength-independent by construction
        return self.throughput

    def __repr__(self) -> str:
        """Return a one-line ``ConstantThroughput(name=..., throughput=...)``."""
        template = "ConstantThroughput(name={!r}, throughput={:.3g})"
        return template.format(self.name, self.throughput)


@final
class SpectralThroughput(AbstractUniformElement):
    """Wavelength-dependent throughput defined by sampled (wavelength, throughput) pairs.

    Throughput is linearly interpolated between samples and is zero outside
    the sampled wavelength range.

    Represents any tabulated wavelength-dependent throughput in the optical
    path: bandpass filters, dichroics, mirror reflectivity, coating losses,
    ADCs, blocking filters, etc.

    Samples may be supplied in any order. They are sorted by wavelength on
    construction, because the interpolator brackets query points with a
    binary search that assumes ascending knots.
    """

    # 1-D sample wavelengths [nm], stored in ascending order.
    wavelengths_nm: Array
    # Throughput at each sample, aligned with ``wavelengths_nm``.
    throughputs: Array
    # Linear interpolator over the samples; returns 0 outside their range.
    interp: interpax.Interpolator1D

    def __init__(self, wavelengths_nm: ArrayLike, throughputs: ArrayLike) -> None:
        """Create a spectral throughput element from sampled pairs.

        Args:
            wavelengths_nm: 1-D sample wavelengths in nanometres, in any order.
            throughputs: Throughput at each sample. Its leading axis must have
                the same length as ``wavelengths_nm``.

        Raises:
            ValueError: If ``wavelengths_nm`` is not 1-D, has fewer than two
                samples, or if ``throughputs`` does not have one entry per
                wavelength sample.
        """
        wl = jnp.asarray(wavelengths_nm)
        tp = jnp.asarray(throughputs)

        # Shapes are static under JAX tracing, so these checks are trace-safe.
        if wl.ndim != 1:
            raise ValueError(f"wavelengths_nm must be 1-D, got shape {wl.shape}")
        if wl.shape[0] < 2:
            raise ValueError(
                "linear interpolation needs at least two wavelength samples, "
                f"got {wl.shape[0]}"
            )
        if tp.ndim == 0 or tp.shape[0] != wl.shape[0]:
            raise ValueError(
                f"throughputs shape {tp.shape} does not match "
                f"{wl.shape[0]} wavelength samples"
            )

        # interpax locates the bracketing interval with searchsorted, which
        # silently returns wrong values for unsorted knots; sort once here.
        order = jnp.argsort(wl)
        wl_sorted = wl[order]
        tp_sorted = tp[order]

        self.wavelengths_nm = wl_sorted
        self.throughputs = tp_sorted
        self.interp = interpax.Interpolator1D(
            wl_sorted,
            tp_sorted,
            method="linear",
            # Return 0 below the first and above the last sample.
            extrap=jnp.array([0.0, 0.0]),
        )

    def get_throughput(self, wavelength_nm: ArrayLike) -> ArrayLike:
        """Interpolate throughput at the requested wavelength(s)."""
        return self.interp(wavelength_nm)

    def __repr__(self) -> str:
        """One-line summary of throughput-table extent and sample count."""
        n = int(self.wavelengths_nm.shape[0])
        wl_min = float(self.wavelengths_nm.min())
        wl_max = float(self.wavelengths_nm.max())
        t_min = float(self.throughputs.min())
        t_max = float(self.throughputs.max())
        return (
            f"SpectralThroughput(wl={wl_min:.0f}-{wl_max:.0f} nm, "
            f"n={n}, throughput={t_min:.3g}-{t_max:.3g})"
        )