Add macOS Metal GPU backend support#79
Conversation
Introduce a runtime config helper to detect macOS+Metal and set jax_enable_x64 accordingly (jaxtronomy/_runtime_config.py). Replace ad-hoc jax.config.update(...) calls with configure_jax_precision_for_runtime() across the codebase. Add Metal workarounds: a Metal-safe real-valued EPL series implementation and variants (function_/derivatives_/hessian_real_series) plus spatial (lax.conv) convolution fallback when FFT legalization fails on Metal. Normalize backend labels and add Metal-specific input sanitization/clip logic in the sampler and PSO initialization. Make the optax minimizer import lazy in fitting_sequence and add an install extra for macOS Metal in setup.py. Update README with macOS/Metal install and test instructions, and add tests for runtime config and EPL GPU-vs-CPU parity (including guarding legacy tests to use CPU by default on macOS). Misc: small numeric nan/inf clipping tweaks and improved error messages. These changes enable safer Metal GPU execution and provide parity checks while preserving CPU defaults for existing tests.
for more information, see https://pre-commit.ci
|
Hi @trivialTZ , sorry to jump here but I am curious about the JAX+Metal support. To my knowledge, it seems |
|
Hi @aymgal , unfortunately, yes: on Mac, GPU support currently requires downgrading JAX to 0.5.0. Because of that, I would prefer this PR to remain an optional path for users on Mac who explicitly choose GPU, rather than something we depend on by default. I also looked at the recent jax-mps work. It looks promising, but at least for now it seems more like an actively evolving community backend than something I would want to treat as a stable dependency for this PR. So the idea behind this PR is simply to preserve an optional Mac + GPU route where it works. From what I tested, JAX 0.5.0 can handle all the features I currently need (multi-band PSO, etc.), and all tests pass in that setup, so I pushed the PR here. |
|
Thanks for your thoughts @trivialTZ ! Indeed |
Summary
with function_real_series, derivatives_real_series, and hessian_real_series variants
Test plan