diff --git a/CLAUDE.md b/CLAUDE.md index ede0ab3..0571997 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,6 +29,7 @@ SHINE/ ├── shine/ # Main Python package │ ├── __init__.py │ ├── config.py # Base inference configuration +│ ├── prior_utils.py # Shared prior-parsing (config → NumPyro sample sites) │ ├── inference.py # Inference engine (MAP, NUTS) │ └── euclid/ # Euclid VIS instrument backend │ ├── config.py # Euclid-specific configuration @@ -55,7 +56,8 @@ SHINE/ | Module | Status | Purpose | |--------|--------|---------| -| `shine.config` | Implemented | Inference configuration (MAP, NUTS, VI settings) | +| `shine.config` | Implemented | Configuration schema (galaxy model, inference, distributions with `center: catalog`) | +| `shine.prior_utils` | Implemented | Shared prior-parsing: converts `DistributionConfig` → NumPyro sample sites | | `shine.inference` | Implemented | Inference engine (MAP optimization, NUTS via NumPyro) | | `shine.euclid` | Implemented | Euclid VIS instrument backend: data loading, scene model, diagnostics | | `shine.scene_modelling` | Planned | Generic NumPyro generative model definitions | @@ -67,7 +69,7 @@ SHINE/ The first instrument backend, providing end-to-end shear inference on Euclid VIS quadrant-level data: -- **`config.py`** — Pydantic configuration: data paths, source selection (SNR, `det_quality_flag`, size filtering), prior distributions, multi-tier stamp sizes +- **`config.py`** — Pydantic configuration: data paths, source selection (SNR, `det_quality_flag`, size filtering), galaxy model specification via shared `GalaxyConfig` (supports `center: catalog` priors), multi-tier stamp sizes - **`data_loader.py`** — Reads quadrant FITS files (SCI/RMS/FLG), PSF grids with bilinear interpolation, background maps, MER catalogs; computes per-source WCS positions, Jacobians, PSF stamps, and visibility - **`scene.py`** — NumPyro generative model: renders Sersic galaxies convolved with spatially-varying PSFs via JAX-GalSim; multi-tier stamp 
sizes (64/128/256 px) with separate `vmap` per tier; standalone `render_model_images()` for post-inference visualization - **`plots.py`** — 3-panel diagnostic figures (observed | model | chi residual) with configurable masking @@ -120,6 +122,22 @@ Future testing should also include: SHINE uses GalSim-compatible YAML configuration with a probabilistic extension: any parameter defined as a distribution (e.g., `type: Normal`) becomes a **latent variable** for inference rather than a fixed simulation value. See `DESIGN.md` Section 6.1 for config examples. +Both the generic `SceneBuilder` and the Euclid `MultiExposureScene` read their probabilistic model from the same `GalaxyConfig` schema in the YAML `gal:` section. The shared `parse_prior()` function in `shine.prior_utils` converts each config entry into a NumPyro sample site. For catalog-centered priors (where the location parameter comes from per-source catalog data), use `center: catalog`: + +```yaml +gal: + type: Exponential + flux: {type: LogNormal, center: catalog, sigma: 0.5} # median from catalog + half_light_radius: {type: LogNormal, center: catalog, sigma: 0.3} + shear: + g1: {type: Normal, mean: 0.0, sigma: 0.05} + g2: {type: Normal, mean: 0.0, sigma: 0.05} + position: + type: Offset + dx: {type: Normal, mean: 0.0, sigma: 0.05} + dy: {type: Normal, mean: 0.0, sigma: 0.05} +``` + ## Development Roadmap 1. **Phase 1:** Prototype with simple parametric models (Sersic) and constant PSF diff --git a/configs/euclid_vis.yaml b/configs/euclid_vis.yaml index c22ba74..4a640a4 100644 --- a/configs/euclid_vis.yaml +++ b/configs/euclid_vis.yaml @@ -2,6 +2,11 @@ # # Points to the bundled test data in data/EUC_VIS_SWL/. # Selects the 600 brightest sources and runs MAP inference for quick testing. +# +# The `gal:` section specifies the probabilistic model: galaxy profile type, +# which parameters are latent variables (distributions) vs fixed, and what +# priors are used. 
Parameters with `center: catalog` use per-source catalog +# values as the distribution location at runtime. data: exposure_paths: @@ -26,12 +31,31 @@ sources: det_quality_exclude_mask: 0x78C max_sources: 600 -priors: - shear_prior_sigma: 0.05 - flux_prior_log_sigma: 0.5 - hlr_prior_log_sigma: 0.3 - ellipticity_prior_sigma: 0.3 - position_prior_sigma: 0.05 +# Galaxy model specification — the probabilistic model is explicit here. +# Each parameter is either a fixed value or a distribution (= latent variable). +gal: + type: Exponential + + shear: + type: G1G2 + g1: {type: Normal, mean: 0.0, sigma: 0.05} + g2: {type: Normal, mean: 0.0, sigma: 0.05} + + # Catalog-centered priors: center="catalog" means the LogNormal median + # is set to each source's catalog value at runtime. + flux: {type: LogNormal, center: catalog, sigma: 0.5} + half_light_radius: {type: LogNormal, center: catalog, sigma: 0.3} + + ellipticity: + type: E1E2 + e1: {type: Normal, mean: 0.0, sigma: 0.3} + e2: {type: Normal, mean: 0.0, sigma: 0.3} + + # Position offsets from catalog positions (in arcsec) + position: + type: Offset + dx: {type: Normal, mean: 0.0, sigma: 0.05} + dy: {type: Normal, mean: 0.0, sigma: 0.05} inference: method: map diff --git a/docs/api/config.md b/docs/api/config.md index a65e69f..c5c343b 100644 --- a/docs/api/config.md +++ b/docs/api/config.md @@ -4,8 +4,12 @@ Configuration handling with Pydantic models. Parses YAML configuration files and validates all parameters. Distribution parameters (Normal, LogNormal, Uniform) are automatically treated as latent -variables for Bayesian inference. The `InferenceConfig` supports three -inference methods (NUTS, MAP, VI) with method-specific config blocks -(`NUTSConfig`, `MAPConfig`, `VIConfig`). +variables for Bayesian inference. Distributions can use `center: "catalog"` to +resolve their location parameter from per-source catalog data at runtime. 
+Position priors support both `Uniform` (absolute pixel positions) and `Offset` +(small offsets from catalog positions) modes. + +The `InferenceConfig` supports three inference methods (NUTS, MAP, VI) with +method-specific config blocks (`NUTSConfig`, `MAPConfig`, `VIConfig`). ::: shine.config diff --git a/docs/api/prior_utils.md b/docs/api/prior_utils.md new file mode 100644 index 0000000..7e12fd5 --- /dev/null +++ b/docs/api/prior_utils.md @@ -0,0 +1,10 @@ +# shine.prior_utils + +Shared prior-parsing utilities for SHINE scene builders. + +Converts `DistributionConfig` entries (or fixed numeric values) into NumPyro +sample sites. Supports catalog-centered priors via the `center="catalog"` +mechanism, where the distribution location parameter comes from per-source +catalog data at runtime. + +::: shine.prior_utils diff --git a/docs/configuration.md b/docs/configuration.md index 8637a15..fd6eef2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -99,9 +99,29 @@ Supported distributions: | Type | Parameters | Description | |------|-----------|-------------| | `Normal` | `mean`, `sigma` | Gaussian prior | -| `LogNormal` | `mean`, `sigma` | Log-normal prior | +| `LogNormal` | `mean`, `sigma` | Log-normal prior (`mean` is the median) | | `Uniform` | `min`, `max` | Uniform prior | +#### Catalog-centered priors + +For instrument-specific backends (e.g. Euclid) where per-source catalog +measurements are available, any `Normal` or `LogNormal` distribution can be +centered on catalog values by using `center: catalog` instead of a fixed `mean`: + +```yaml +flux: + type: LogNormal + center: catalog # median comes from each source's catalog flux at runtime + sigma: 0.5 +``` + +When `center: catalog` is set, the `mean` field is ignored. 
The scene builder +resolves the location parameter from the data catalog at runtime: + +- **LogNormal**: `LogNormal(log(catalog_value_i), sigma)` — the catalog value + is the median +- **Normal**: `Normal(catalog_value_i, sigma)` — the catalog value is the mean + ### Shear Gravitational shear is defined as two components: @@ -140,8 +160,10 @@ gal: ### Position -Galaxy position priors. Values < 1 are treated as fractions of image size; -values >= 1 are absolute pixel coordinates. +Galaxy position priors support two modes: + +**Uniform (absolute positions)** — values < 1 are treated as fractions of image +size; values >= 1 are absolute pixel coordinates. ```yaml gal: @@ -153,6 +175,24 @@ gal: y_max: 0.7 ``` +**Offset (from catalog positions)** — small offsets from known catalog +positions, typically used with instrument backends where source positions +come from a detection catalog. + +```yaml +gal: + position: + type: Offset + dx: + type: Normal + mean: 0.0 + sigma: 0.05 # arcsec + dy: + type: Normal + mean: 0.0 + sigma: 0.05 +``` + ## Inference Section Controls the inference method and its settings. SHINE supports three methods: @@ -245,7 +285,9 @@ inference: | `method` | `"nuts"` / `"map"` / `"vi"` | `"nuts"` | Inference method | | `rng_seed` | int >= 0 | `0` | JAX PRNG seed | -## Complete Example +## Complete Examples + +### Level 0 — simple validation config ```yaml image: @@ -296,3 +338,50 @@ inference: learning_rate: 0.01 rng_seed: 42 ``` + +### Euclid VIS — instrument backend config + +The Euclid config uses the same `gal:` schema but with `center: catalog` +priors and `Offset` positions. The `data:` and `sources:` sections are +specific to the Euclid backend (`EuclidInferenceConfig`). 
+ +```yaml +data: + exposure_paths: + - data/EUC_VIS_SWL/EUC_VIS_SWL-DET-..._3-4-F.fits.gz + psf_path: data/EUC_VIS_SWL/PSF_3-4-F.fits.gz + catalog_path: data/EUC_VIS_SWL/catalogue_3-4-F.fits.gz + pixel_scale: 0.1 + +sources: + min_snr: 10.0 + max_sources: 600 + +gal: + type: Exponential + shear: + type: G1G2 + g1: {type: Normal, mean: 0.0, sigma: 0.05} + g2: {type: Normal, mean: 0.0, sigma: 0.05} + flux: {type: LogNormal, center: catalog, sigma: 0.5} + half_light_radius: {type: LogNormal, center: catalog, sigma: 0.3} + ellipticity: + type: E1E2 + e1: {type: Normal, mean: 0.0, sigma: 0.3} + e2: {type: Normal, mean: 0.0, sigma: 0.3} + position: + type: Offset + dx: {type: Normal, mean: 0.0, sigma: 0.05} + dy: {type: Normal, mean: 0.0, sigma: 0.05} + +inference: + method: map + map_config: + enabled: true + num_steps: 200 + learning_rate: 0.002 + rng_seed: 42 + +galaxy_stamp_sizes: [64, 128, 256] +background: fixed +``` diff --git a/mkdocs.yml b/mkdocs.yml index 9b0f10a..b8b5f24 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -15,6 +15,7 @@ nav: - GPU-Batched Inference: validation/batched.md - API Reference: - shine.config: api/config.md + - shine.prior_utils: api/prior_utils.md - shine.scene: api/scene.md - shine.inference: api/inference.md - shine.data: api/data.md diff --git a/notebooks/euclid_vis_map.ipynb b/notebooks/euclid_vis_map.ipynb index b0f0552..0743e03 100644 --- a/notebooks/euclid_vis_map.ipynb +++ b/notebooks/euclid_vis_map.ipynb @@ -13,40 +13,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings(\"ignore\", message=\".*complex128.*\", module=\"jax_galsim\")\n", - "\n", - "from pathlib import Path\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from shine.config import InferenceConfig, MAPConfig\n", - "from shine.euclid.config import (\n", - " 
EuclidDataConfig,\n", - " EuclidInferenceConfig,\n", - " SourceSelectionConfig,\n", - ")\n", - "from shine.euclid.data_loader import EuclidDataLoader\n", - "from shine.euclid.scene import MultiExposureScene, render_model_images\n", - "from shine.euclid.plots import plot_exposure_comparison\n", - "from shine.inference import Inference\n", - "\n", - "%matplotlib inline\n", - "\n", - "# --- Configuration ---\n", - "DATA_DIR = Path(\"../data/EUC_VIS_SWL\")\n", - "MAX_SOURCES = 1000\n", - "MIN_SNR = 10.0\n", - "MAP_STEPS = 220\n", - "RNG_SEED = 42" - ] + "source": "import warnings\nwarnings.filterwarnings(\"ignore\", message=\".*complex128.*\", module=\"jax_galsim\")\n\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom shine.config import InferenceConfig, MAPConfig\nfrom shine.euclid.config import (\n EuclidDataConfig,\n EuclidInferenceConfig,\n SourceSelectionConfig,\n)\nfrom shine.euclid.data_loader import EuclidDataLoader\nfrom shine.euclid.scene import MultiExposureScene, render_model_images\nfrom shine.euclid.plots import plot_exposure_comparison\nfrom shine.inference import Inference\n\n%matplotlib inline\n\n# --- Configuration ---\nDATA_DIR = Path(\"../data/EUC_VIS_SWL\")\nMAX_SOURCES = 1000\nMIN_SNR = 10.0\nMAP_STEPS = 220\nRNG_SEED = 42" }, { "cell_type": "markdown", @@ -512,4 +482,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/shine/config.py b/shine/config.py index faa517f..b93c3a4 100644 --- a/shine/config.py +++ b/shine/config.py @@ -8,14 +8,25 @@ class DistributionConfig(BaseModel): """Configuration for probability distributions used as priors. - Supports Normal, LogNormal, and Uniform distributions with appropriate parameters. + Supports Normal, LogNormal, and Uniform distributions with appropriate + parameters. 
@field_validator("center")
@classmethod
def validate_center(cls, v: Optional[str]) -> Optional[str]:
    """Accept ``center`` only when it is unset or the literal ``"catalog"``.

    Args:
        v: Raw ``center`` value taken from the configuration entry.

    Returns:
        ``v`` unchanged when it is ``None`` or ``"catalog"``.

    Raises:
        ValueError: If ``v`` is any other value.
    """
    # Accept-first guard: the two permitted values leave immediately.
    if v is None or v == "catalog":
        return v
    raise ValueError(
        f"center must be 'catalog' or omitted, got '{v}'"
    )
+ "Normal distribution requires 'mean' and 'sigma' parameters " + "(or set center='catalog' to use catalog values as mean)" + ) + if catalog_centered and self.sigma is None: + raise ValueError( + "Normal distribution with center='catalog' requires 'sigma'" + ) + + if self.type == "LogNormal": + if not catalog_centered and (self.mean is None or self.sigma is None): + raise ValueError( + "LogNormal distribution requires 'mean' and 'sigma' parameters " + "(or set center='catalog' to use catalog values as median)" + ) + if catalog_centered and self.sigma is None: + raise ValueError( + "LogNormal distribution with center='catalog' requires 'sigma'" + ) if self.type == "Uniform" and (self.min is None or self.max is None): raise ValueError( @@ -214,15 +255,23 @@ class EllipticityConfig(BaseModel): class PositionConfig(BaseModel): """Configuration for galaxy position priors. - Defines the prior distribution over galaxy positions in the image. - Values less than 1 are treated as fractions of image size, values >= 1 as pixels. + Supports two modes: + + * ``type="Uniform"`` — absolute pixel positions drawn from a uniform + distribution. Values less than 1 are treated as fractions of image + size, values >= 1 as pixels. + * ``type="Offset"`` — small position offsets (e.g. from catalog + positions) specified as ``dx`` and ``dy``, each of which can be a + fixed value or a :class:`DistributionConfig`. Attributes: - type: Distribution type for positions (default 'Uniform'). - x_min: Minimum x position (fraction if < 1, pixels if >= 1). - x_max: Maximum x position (fraction if < 1, pixels if >= 1). - y_min: Minimum y position (fraction if < 1, pixels if >= 1). - y_max: Maximum y position (fraction if < 1, pixels if >= 1). + type: Position mode (``'Uniform'`` or ``'Offset'``). + x_min: Minimum x position (Uniform mode). + x_max: Maximum x position (Uniform mode). + y_min: Minimum y position (Uniform mode). + y_max: Maximum y position (Uniform mode). 
@model_validator(mode="after")
def validate_position_config(self) -> "PositionConfig":
    """Check that the fields match the selected position mode.

    ``"Offset"`` requires both ``dx`` and ``dy``. ``"Uniform"`` requires
    that any fully specified ``min``/``max`` pair is strictly increasing.
    Every other ``type`` is rejected.

    Returns:
        The validated model instance.

    Raises:
        ValueError: If an Offset config lacks ``dx`` or ``dy``, if a
            Uniform bound pair is not strictly increasing, or if ``type``
            is neither ``'Uniform'`` nor ``'Offset'``.
    """
    mode = self.type

    if mode == "Offset":
        # Both offset components must be present (fixed value or prior).
        if self.dx is None or self.dy is None:
            raise ValueError(
                "Position type 'Offset' requires 'dx' and 'dy' fields"
            )
        return self

    if mode != "Uniform":
        raise ValueError(
            f"Position type must be 'Uniform' or 'Offset', got '{self.type}'"
        )

    # Uniform mode: bounds are optional, but a supplied pair must be ordered.
    x_pair_given = self.x_min is not None and self.x_max is not None
    if x_pair_given and self.x_min >= self.x_max:
        raise ValueError(
            f"x_min ({self.x_min}) must be less than x_max ({self.x_max})"
        )

    y_pair_given = self.y_min is not None and self.y_max is not None
    if y_pair_given and self.y_min >= self.y_max:
        raise ValueError(
            f"y_min ({self.y_min}) must be less than y_max ({self.y_max})"
        )

    return self
"SourceSelectionConfig", - "PriorConfig", "EuclidInferenceConfig", "EuclidPSFModel", "EuclidExposure", diff --git a/shine/euclid/config.py b/shine/euclid/config.py index a4b2743..8477ec4 100644 --- a/shine/euclid/config.py +++ b/shine/euclid/config.py @@ -1,14 +1,24 @@ """Pydantic configuration models for Euclid VIS shear inference. Provides structured, validated configuration for Euclid VIS data paths, -source selection criteria, prior distributions, and inference settings. +source selection criteria, galaxy model specification, and inference +settings. The galaxy model (priors, profile type) is specified via the +shared :class:`~shine.config.GalaxyConfig`, making the probabilistic +model explicit in the YAML configuration file. """ from typing import List, Literal, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator -from shine.config import InferenceConfig +from shine.config import ( + DistributionConfig, + EllipticityConfig, + GalaxyConfig, + InferenceConfig, + PositionConfig, + ShearConfig, +) class EuclidDataConfig(BaseModel): @@ -133,67 +143,54 @@ def validate_max_sources_positive(cls, v: Optional[int]) -> Optional[int]: return v -class PriorConfig(BaseModel): - """Prior distribution parameters for Bayesian inference. +def _default_euclid_galaxy_config() -> GalaxyConfig: + """Build the default Euclid galaxy model specification. - Defines the width (sigma) of prior distributions on galaxy and - shear parameters. All priors are centered on catalog values or - zero (for shear). - - Attributes: - shear_prior_sigma: Width of the shear prior (default 0.05). - flux_prior_log_sigma: Width of the log-flux prior (default 0.5). - hlr_prior_log_sigma: Width of the log-half-light-radius prior - (default 0.3). - ellipticity_prior_sigma: Width of the ellipticity prior - (default 0.3). - position_prior_sigma: Width of the position prior in arcsec - (default 0.05). 
+ The default model uses catalog-centered LogNormal priors for flux and + half-light radius, Normal priors for shear and ellipticity, and + Normal offset priors for position. This matches the priors that were + previously hard-coded in ``MultiExposureScene._sample_parameters()``. """ - - shear_prior_sigma: float = 0.05 - flux_prior_log_sigma: float = 0.5 - hlr_prior_log_sigma: float = 0.3 - ellipticity_prior_sigma: float = 0.3 - position_prior_sigma: float = 0.05 - - @field_validator( - "shear_prior_sigma", - "flux_prior_log_sigma", - "hlr_prior_log_sigma", - "ellipticity_prior_sigma", - "position_prior_sigma", + return GalaxyConfig( + type="Exponential", + flux=DistributionConfig(type="LogNormal", center="catalog", sigma=0.5), + half_light_radius=DistributionConfig( + type="LogNormal", center="catalog", sigma=0.3 + ), + shear=ShearConfig( + type="G1G2", + g1=DistributionConfig(type="Normal", mean=0.0, sigma=0.05), + g2=DistributionConfig(type="Normal", mean=0.0, sigma=0.05), + ), + ellipticity=EllipticityConfig( + type="E1E2", + e1=DistributionConfig(type="Normal", mean=0.0, sigma=0.3), + e2=DistributionConfig(type="Normal", mean=0.0, sigma=0.3), + ), + position=PositionConfig( + type="Offset", + dx=DistributionConfig(type="Normal", mean=0.0, sigma=0.05), + dy=DistributionConfig(type="Normal", mean=0.0, sigma=0.05), + ), ) - @classmethod - def validate_sigma_positive(cls, v: float, info) -> float: - """Validate that all prior sigma values are positive. - - Args: - v: Sigma value to validate. - info: Pydantic field validation info. - - Returns: - The validated sigma value. - - Raises: - ValueError: If sigma is not positive. - """ - if v <= 0: - raise ValueError(f"{info.field_name} must be positive, got {v}") - return v class EuclidInferenceConfig(BaseModel): """Top-level configuration for Euclid VIS shear inference. - Combines data paths, source selection, priors, and inference settings - into a single validated configuration. 
def _sample_parameters(self) -> tuple:
    """Draw the shared shear and the per-source galaxy parameters.

    All priors come from ``self.config.gal``. The flux and half-light
    radius entries are given the matching per-source catalog arrays from
    ``self.data`` so that ``center="catalog"`` priors can be located on
    them.

    Returns:
        Tuple ``(g1, g2, flux, hlr, e1, e2, dx, dy)``.
    """
    galaxy = self.config.gal
    catalog = self.data

    # Shear sites are created before the source plate is opened.
    shear_g1 = parse_prior("g1", galaxy.shear.g1)
    shear_g2 = parse_prior("g2", galaxy.shear.g2)

    # NOTE(review): e1/e2/dx/dy are kept inside the plate, as in the
    # previous hard-coded implementation -- confirm against the repository.
    with numpyro.plate("sources", catalog.n_sources):
        flux = parse_prior(
            "flux",
            galaxy.flux,
            catalog_values=catalog.catalog_flux_adu,
        )
        hlr = parse_prior(
            "hlr",
            galaxy.half_light_radius,
            catalog_values=catalog.catalog_hlr_arcsec,
        )

        # Intrinsic ellipticity defaults to zero when not configured.
        ellipticity = galaxy.ellipticity
        if ellipticity is None:
            e1 = 0.0
            e2 = 0.0
        else:
            e1 = parse_prior("e1", ellipticity.e1)
            e2 = parse_prior("e2", ellipticity.e2)

        # Position offsets are only sampled in "Offset" mode.
        if galaxy.position is not None and galaxy.position.type == "Offset":
            dx = parse_prior("dx", galaxy.position.dx)
            dy = parse_prior("dy", galaxy.position.dy)
        else:
            dx = 0.0
            dy = 0.0

    return shear_g1, shear_g2, flux, hlr, e1, e2, dx, dy
def parse_prior(
    name: str,
    param_config: Union[float, int, "DistributionConfig"],
    catalog_values: Optional[jnp.ndarray] = None,
) -> Union[float, jnp.ndarray]:
    """Create a NumPyro sample site from a config entry, or return a fixed value.

    Args:
        name: Parameter name for the NumPyro sample site.
        param_config: Either a fixed numeric value or a
            :class:`DistributionConfig` describing the prior distribution.
        catalog_values: Per-source catalog values used as the location
            parameter when ``param_config.center == "catalog"``. Required
            when catalog-centered priors are used; ignored otherwise.

    Returns:
        Sampled value(s) from the distribution, or the fixed value as a
        ``float``.

    Raises:
        ValueError: If the distribution type is not recognized, if
            ``center="catalog"`` is used but *catalog_values* is ``None``,
            if ``center="catalog"`` is combined with a ``Uniform`` prior
            (which has no location parameter), or if a non-catalog
            ``LogNormal`` prior has a missing or non-positive ``mean``.
    """
    # Fixed values are not latent variables: no sample site is created.
    if isinstance(param_config, (float, int)):
        return float(param_config)

    # getattr keeps this working for config objects built without ``center``.
    catalog_centered = getattr(param_config, "center", None) == "catalog"

    if catalog_centered and catalog_values is None:
        raise ValueError(
            f"Parameter '{name}' has center='catalog' but no catalog_values "
            f"were provided"
        )

    if param_config.type == "Normal":
        location = catalog_values if catalog_centered else param_config.mean
        return numpyro.sample(name, dist.Normal(location, param_config.sigma))

    if param_config.type == "LogNormal":
        if catalog_centered:
            median = catalog_values
        else:
            median = param_config.mean
            # ``mean`` is the natural-space median; log() of a non-positive
            # value would silently produce a NaN/-inf location.
            if median is None or median <= 0:
                raise ValueError(
                    f"Parameter '{name}': LogNormal requires a positive "
                    f"'mean' (the median), got {median}"
                )
        return numpyro.sample(
            name, dist.LogNormal(jnp.log(median), param_config.sigma)
        )

    if param_config.type == "Uniform":
        # A Uniform prior has no location parameter, so silently ignoring
        # center='catalog' would hide a configuration mistake.
        if catalog_centered:
            raise ValueError(
                f"Parameter '{name}': center='catalog' is not supported "
                f"for Uniform distributions"
            )
        return numpyro.sample(
            name, dist.Uniform(param_config.min, param_config.max)
        )

    raise ValueError(f"Unknown distribution type: '{param_config.type}'")
""" - if isinstance(param_config, (float, int)): - return float(param_config) - - if param_config.type == "Normal": - return numpyro.sample( - name, dist.Normal(param_config.mean, param_config.sigma) - ) - if param_config.type == "LogNormal": - return numpyro.sample( - name, dist.LogNormal(jnp.log(param_config.mean), param_config.sigma) - ) - if param_config.type == "Uniform": - return numpyro.sample( - name, dist.Uniform(param_config.min, param_config.max) - ) - raise ValueError(f"Unknown distribution type: '{param_config.type}'") + return parse_prior(name, param_config) @staticmethod def _resolve_bound( diff --git a/tests/test_config.py b/tests/test_config.py index 81d8dc0..e661988 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,6 +17,7 @@ NoiseConfig, NUTSConfig, PSFConfig, + PositionConfig, ShearConfig, ShineConfig, VIConfig, @@ -406,3 +407,75 @@ def test_load_config_with_distributions(self): finally: Path(tmp_path).unlink() + +class TestDistributionConfigCenter: + """Test DistributionConfig catalog-centered priors.""" + + def test_lognormal_center_catalog(self): + """LogNormal with center='catalog' is valid without mean.""" + cfg = DistributionConfig(type="LogNormal", center="catalog", sigma=0.5) + assert cfg.center == "catalog" + assert cfg.sigma == 0.5 + + def test_normal_center_catalog(self): + """Normal with center='catalog' is valid without mean.""" + cfg = DistributionConfig(type="Normal", center="catalog", sigma=0.1) + assert cfg.center == "catalog" + + def test_invalid_center_value(self): + """Invalid center value raises error.""" + with pytest.raises(ValueError, match="center must be 'catalog'"): + DistributionConfig(type="Normal", center="data", sigma=0.1) + + def test_catalog_center_without_sigma_raises(self): + """center='catalog' without sigma raises error.""" + with pytest.raises(ValueError, match="requires 'sigma'"): + DistributionConfig(type="LogNormal", center="catalog") + + def 
test_normal_without_center_still_requires_mean(self): + """Normal without center still requires mean.""" + with pytest.raises(ValueError, match="requires 'mean' and 'sigma'"): + DistributionConfig(type="Normal", sigma=0.1) + + +class TestPositionConfigOffset: + """Test PositionConfig with Offset type.""" + + def test_offset_with_distributions(self): + """Offset position with distributions is valid.""" + dx = DistributionConfig(type="Normal", mean=0.0, sigma=0.05) + dy = DistributionConfig(type="Normal", mean=0.0, sigma=0.05) + cfg = PositionConfig(type="Offset", dx=dx, dy=dy) + assert cfg.type == "Offset" + assert isinstance(cfg.dx, DistributionConfig) + assert isinstance(cfg.dy, DistributionConfig) + + def test_offset_with_fixed_values(self): + """Offset position with fixed values is valid.""" + cfg = PositionConfig(type="Offset", dx=0.0, dy=0.0) + assert cfg.dx == 0.0 + assert cfg.dy == 0.0 + + def test_offset_missing_dx_raises(self): + """Offset without dx raises error.""" + with pytest.raises(ValueError, match="requires 'dx' and 'dy'"): + PositionConfig(type="Offset", dy=0.0) + + def test_offset_missing_dy_raises(self): + """Offset without dy raises error.""" + with pytest.raises(ValueError, match="requires 'dx' and 'dy'"): + PositionConfig(type="Offset", dx=0.0) + + def test_invalid_type_raises(self): + """Invalid position type raises error.""" + with pytest.raises(ValueError, match="must be 'Uniform' or 'Offset'"): + PositionConfig(type="Grid") + + def test_uniform_still_works(self): + """Uniform position type still works as before.""" + cfg = PositionConfig( + type="Uniform", x_min=10.0, x_max=20.0, y_min=10.0, y_max=20.0 + ) + assert cfg.type == "Uniform" + assert cfg.x_min == 10.0 + diff --git a/tests/test_prior_utils.py b/tests/test_prior_utils.py new file mode 100644 index 0000000..6939916 --- /dev/null +++ b/tests/test_prior_utils.py @@ -0,0 +1,118 @@ +"""Tests for shine.prior_utils module.""" + +import jax.numpy as jnp +import numpyro +import 
numpyro.handlers as handlers +import pytest +from jax import random + +from shine.config import DistributionConfig +from shine.prior_utils import parse_prior + + +class TestParsepriorFixedValues: + """Test parse_prior with fixed numeric values.""" + + def test_float_passthrough(self): + """Fixed float values are returned directly.""" + assert parse_prior("x", 1.5) == 1.5 + + def test_int_passthrough(self): + """Fixed int values are returned as float.""" + result = parse_prior("x", 3) + assert result == 3.0 + assert isinstance(result, float) + + +class TestParsePriorDistributions: + """Test parse_prior with standard distribution configs.""" + + def test_normal_distribution(self): + """Normal distribution creates a sample site with correct params.""" + cfg = DistributionConfig(type="Normal", mean=1.0, sigma=0.5) + + def model(): + return parse_prior("x", cfg) + + rng = random.PRNGKey(0) + trace = handlers.trace(handlers.seed(model, rng)).get_trace() + assert "x" in trace + assert trace["x"]["type"] == "sample" + + def test_lognormal_distribution(self): + """LogNormal distribution creates a sample site.""" + cfg = DistributionConfig(type="LogNormal", mean=100.0, sigma=0.5) + + def model(): + return parse_prior("x", cfg) + + rng = random.PRNGKey(0) + trace = handlers.trace(handlers.seed(model, rng)).get_trace() + assert "x" in trace + # LogNormal samples are always positive + assert trace["x"]["value"] > 0 + + def test_uniform_distribution(self): + """Uniform distribution creates a sample site.""" + cfg = DistributionConfig(type="Uniform", min=0.0, max=10.0) + + def model(): + return parse_prior("x", cfg) + + rng = random.PRNGKey(0) + trace = handlers.trace(handlers.seed(model, rng)).get_trace() + assert "x" in trace + val = float(trace["x"]["value"]) + assert 0.0 <= val <= 10.0 + + def test_unknown_distribution_raises(self): + """Unknown distribution type raises ValueError.""" + cfg = DistributionConfig.model_construct(type="Cauchy", sigma=1.0) + with 
pytest.raises(ValueError, match="Unknown distribution type"): + parse_prior("x", cfg) + + +class TestParsePriorCatalogCentered: + """Test parse_prior with center='catalog' priors.""" + + def test_lognormal_catalog_centered(self): + """LogNormal with center='catalog' uses catalog values as median.""" + cfg = DistributionConfig(type="LogNormal", center="catalog", sigma=0.5) + catalog = jnp.array([100.0, 200.0, 300.0]) + + def model(): + with numpyro.plate("sources", 3): + return parse_prior("flux", cfg, catalog_values=catalog) + + rng = random.PRNGKey(0) + trace = handlers.trace(handlers.seed(model, rng)).get_trace() + assert "flux" in trace + assert trace["flux"]["value"].shape == (3,) + # All samples should be positive (LogNormal) + assert jnp.all(trace["flux"]["value"] > 0) + + def test_normal_catalog_centered(self): + """Normal with center='catalog' uses catalog values as mean.""" + cfg = DistributionConfig(type="Normal", center="catalog", sigma=0.1) + catalog = jnp.array([1.0, 2.0, 3.0]) + + def model(): + with numpyro.plate("sources", 3): + return parse_prior("pos", cfg, catalog_values=catalog) + + rng = random.PRNGKey(0) + trace = handlers.trace(handlers.seed(model, rng)).get_trace() + assert "pos" in trace + assert trace["pos"]["value"].shape == (3,) + + def test_catalog_centered_without_values_raises(self): + """center='catalog' without catalog_values raises ValueError.""" + cfg = DistributionConfig(type="LogNormal", center="catalog", sigma=0.5) + with pytest.raises(ValueError, match="no catalog_values"): + parse_prior("flux", cfg) + + def test_catalog_centered_none_values_raises(self): + """center='catalog' with catalog_values=None raises ValueError.""" + cfg = DistributionConfig(type="LogNormal", center="catalog", sigma=0.5) + with pytest.raises(ValueError, match="no catalog_values"): + parse_prior("flux", cfg, catalog_values=None)