Source code for optixstuff.disperser
"""Disperser hardware descriptors for integral field spectrographs.
optixstuff owns the descriptor: the interface plus a cheap closed-form
scalar/ETC face. The heavy render logic (building the forward operator) lives
in coronachrome. This mirrors the coronagraph split, so jaxedith and yield
tools can read IFS hardware info without importing the render engine.
"""
import abc
import equinox as eqx
import jax.numpy as jnp
from jax import Array
[docs]
class AbstractDisperser(eqx.Module):
    """Common interface for a dispersing IFS element.

    Concrete descriptors (lenslet array, slicer, MSA) implement the four
    scalar/ETC queries declared below. No render geometry is defined on this
    interface; per the module notes it is built in coronachrome, which
    dispatches on the concrete descriptor type.
    """

    @abc.abstractmethod
    def spectral_resolution(self, wavelength_nm):
        """Return the resolving power R = lambda / dlambda.

        Args:
            wavelength_nm: Wavelength(s) at which R is evaluated
                (nanometers, going by the parameter name).
        """

    @abc.abstractmethod
    def spectral_sampling(self):
        """Return the number of detector pixels per resolution element."""

    @abc.abstractmethod
    def n_pix_spread(self, wavelength_min_nm, wavelength_max_nm):
        """Return how many detector pixels one spaxel spectrum covers.

        Args:
            wavelength_min_nm: Lower edge of the band.
            wavelength_max_nm: Upper edge of the band.
        """

    @abc.abstractmethod
    def throughput(self, wavelength_nm):
        """Return the disperser optical throughput, a value in [0, 1].

        Args:
            wavelength_nm: Wavelength(s) at which throughput is evaluated.
        """
[docs]
def _polyval_deriv(coeffs, x):
    """Evaluate the derivative of a descending-order polynomial at x.

    ``coeffs`` follows the ``jnp.polyval`` convention (highest power first),
    so for ``p(x) = c[0]*x**(n-1) + ... + c[n-1]`` this returns
    ``p'(x) = (n-1)*c[0]*x**(n-2) + ... + c[n-2]``.

    Args:
        coeffs: 1-D polynomial coefficients, highest power first. Any
            array-like (list, tuple, NumPy or JAX array) is accepted.
        x: Scalar or array of evaluation points.

    Returns:
        The derivative evaluated at ``x``, shaped like ``x``. A constant or
        empty polynomial yields zeros.
    """
    # Coerce up front: plain sequences have no ``.shape`` and would
    # otherwise fail on the length lookup below.
    coeffs = jnp.asarray(coeffs)
    n = coeffs.shape[0]
    if n <= 1:
        # Derivative of a constant (or empty) polynomial is identically zero.
        return jnp.zeros_like(jnp.asarray(x, dtype=float))
    # Power-rule multipliers n-1, ..., 1; the constant term drops out.
    powers = jnp.arange(n - 1, 0, -1)
    return jnp.polyval(coeffs[:-1] * powers, x)
[docs]
class LensletDisperser(AbstractDisperser):
    """Descriptor for a lenslet-array IFS disperser (CRISPY heritage).

    This class is configuration only: coronachrome reads the fields to build
    the render geometry. The scalar/ETC methods are all derived from
    ``dispersion_coeffs`` and ``pix_per_reselt``, which keeps the dispersion
    model as the single source of truth.
    """

    # Geometry fields; not read by the scalar methods in this class.
    pitch_m: float
    pixsize_m: float
    angle_rad: float
    # Reference wavelength; inputs are divided by it before taking the log.
    lam_ref_nm: float
    # Detector pixels per resolution element (returned by spectral_sampling).
    pix_per_reselt: float
    # Descending-order polynomial in u = ln(lambda / lam_ref_nm) giving the
    # spectral-axis detector offset in pixels.
    dispersion_coeffs: Array = eqx.field(converter=jnp.asarray)
    # Element 0 is added as the margin in n_pix_spread; the meaning of any
    # further elements is not visible here.
    psflet_params: Array = eqx.field(converter=jnp.asarray)
    # Static configuration; not read by the scalar methods in this class.
    grid_kind: str = eqx.field(static=True)
    n_lenslets: int = eqx.field(static=True)
    psflet_kind: str = eqx.field(static=True)
    detector_shape: tuple[int, int] = eqx.field(static=True)
    # Wavelength-independent throughput returned by throughput().
    throughput_value: float = 1.0

    def _log_wavelength_ratio(self, wavelength_nm):
        """Return u = ln(wavelength / lam_ref_nm), the dispersion variable."""
        lam = jnp.asarray(wavelength_nm, dtype=float)
        return jnp.log(lam / self.lam_ref_nm)

    def _dispersion_px(self, wavelength_nm):
        """Return the spectral-axis detector offset [px] for the wavelength(s)."""
        log_ratio = self._log_wavelength_ratio(wavelength_nm)
        return jnp.polyval(self.dispersion_coeffs, log_ratio)

    def spectral_resolution(self, wavelength_nm):
        """Return R = (px per unit log-lambda) / (px per resolution element).

        The dispersion polynomial maps ln(lambda) to pixels, so the magnitude
        of its derivative is the local number of pixels per unit ln(lambda).
        """
        log_ratio = self._log_wavelength_ratio(wavelength_nm)
        px_per_log_lambda = jnp.abs(
            _polyval_deriv(self.dispersion_coeffs, log_ratio)
        )
        return px_per_log_lambda / self.pix_per_reselt

    def spectral_sampling(self):
        """Return the detector pixels per resolution element."""
        return self.pix_per_reselt

    def n_pix_spread(self, wavelength_min_nm, wavelength_max_nm):
        """Return the trace length [px] across a band plus a PSFlet margin.

        The length is the absolute difference of the dispersion offsets at
        the two band edges; ``psflet_params[0]`` is then added as the margin.
        """
        px_lo = self._dispersion_px(wavelength_min_nm)
        px_hi = self._dispersion_px(wavelength_max_nm)
        margin = self.psflet_params[0]
        return jnp.abs(px_hi - px_lo) + margin

    def throughput(self, wavelength_nm):
        """Return the constant v1 throughput, broadcast to the input shape."""
        lam = jnp.asarray(wavelength_nm, dtype=float)
        return self.throughput_value * jnp.ones_like(lam)