Skip to content

Add macOS Metal GPU backend support#79

Open
trivialTZ wants to merge 2 commits into
lenstronomy:mainfrom
trivialTZ:main
Open

Add macOS Metal GPU backend support#79
trivialTZ wants to merge 2 commits into
lenstronomy:mainfrom
trivialTZ:main

Conversation

@trivialTZ
Copy link
Copy Markdown

Summary

  • Introduce jaxtronomy/_runtime_config.py to detect macOS + Metal backend and auto-configure jax_enable_x64 (with JAXTRONOMY_ENABLE_X64 env override)
  • Replace scattered jax.config.update("jax_enable_x64", ...) calls across 15+ files with a single configure_jax_precision_for_runtime() entry point
  • Add Metal-safe EPL implementation using a real-valued omega series (_omega_real_series) to avoid complex-number ops unsupported on Metal,
    with function_real_series, derivatives_real_series, and hessian_real_series variants
  • Add spatial convolution fallback (lax.conv_general_dilated) in PixelKernelConvolution when FFT legalization fails on Metal
  • Normalize backend labels and add Metal-specific input sanitization in sampler/PSO initialization
  • Make optax import lazy in fitting_sequence.py and add macos install extra in setup.py
  • Update README with macOS/Metal install and test instructions

Test plan

  • pytest test/test_runtime_config.py — runtime config detection and env override logic
  • pytest test/test_LensModel/test_Profiles/test_epl.py — EPL profile correctness (Metal series + standard paths)
  • pytest test/test_LensModel/test_Profiles/test_epl_gpu_cpu_parity.py — GPU vs CPU numerical parity
  • pytest test/test_Sampling/test_Samplers/test_optax.py — optax sampler with lazy import
  • pytest test/test_Sampling/test_sampler.py — sampler Metal normalization
  • Full test suite on CPU-only and macOS Metal environments

trivialTZ and others added 2 commits February 26, 2026 00:03
Introduce a runtime config helper to detect macOS+Metal and set jax_enable_x64 accordingly (jaxtronomy/_runtime_config.py). Replace ad-hoc jax.config.update(...) calls with configure_jax_precision_for_runtime() across the codebase. Add Metal workarounds: a Metal-safe real-valued EPL series implementation and variants (function_/derivatives_/hessian_real_series) plus spatial (lax.conv) convolution fallback when FFT legalization fails on Metal. Normalize backend labels and add Metal-specific input sanitization/clip logic in the sampler and PSO initialization. Make the optax minimizer import lazy in fitting_sequence and add an install extra for macOS Metal in setup.py. Update README with macOS/Metal install and test instructions, and add tests for runtime config and EPL GPU-vs-CPU parity (including guarding legacy tests to use CPU by default on macOS). Misc: small numeric nan/inf clipping tweaks and improved error messages. These changes enable safer Metal GPU execution and provide parity checks while preserving CPU defaults for existing tests.
@aymgal
Copy link
Copy Markdown
Collaborator

aymgal commented Mar 9, 2026

Hi @trivialTZ , sorry to jump here but I am curious about the JAX+Metal support. To my knowledge, it seems jax_metal is not maintained anymore since 2024. So I was wondering if this is, on the long run, a good idea to support it. But perhaps I am missing something here and in that case, I would be interested to know more from your idea behind this PR, as running JAX on Metal sounds obviously very interesting in the first place.

@trivialTZ
Copy link
Copy Markdown
Author

trivialTZ commented Mar 11, 2026

Hi @aymgal , unfortunately, yes: on Mac, GPU support currently requires downgrading JAX to 0.5.0. Because of that, I would prefer this PR to remain an optional path for users on Mac who explicitly choose GPU, rather than something we depend on by default.

I also looked at the recent jax-mps work. It looks promising, but at least for now it seems more like an actively evolving community backend than something I would want to treat as a stable dependency for this PR. So the idea behind this PR is simply to preserve an optional Mac + GPU route where it works.

From what I tested, JAX 0.5.0 can handle all the features I currently need (multi-band PSO, etc.), and all tests pass in that setup, so I pushed the PR here.

@aymgal
Copy link
Copy Markdown
Collaborator

aymgal commented Mar 11, 2026

Thanks for your thoughts @trivialTZ ! Indeed jax-mps looks promising.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants