Releases: HomebrewML/HeavyBall
v3.0.0
HeavyBall 3.0.0
Highlights
- Simplified public API: `Foreach*` prefixes removed; short names are now the canonical classes
- New optimizers: `HyperBallAdamW`, `MuonAdamW`, `LATHER`, `PSGDPRO`. LATHER, "Lie-group Adam Through Harmonic Eigenbasis Rotations", performs AdamW in the PSGD eigenbasis
- Route-based param dispatch replaces manual `SplitOpt` for mixed-architecture optimizers
- `ScheduleFree` and `MSAM` mode switches are now idempotent (calling `eval()` twice is safe)
- Higher-precision PSGD preconditioner updates
- New `consume_grad` option: `step()` clears `p.grad` after consuming it by default; set `consume_grad=False` to keep gradients attached after the step
- `orig_shapes` is now an explicit, documented optimizer argument; use `capture_param_shapes(...)` before wrapping models with sharding backends that do not preserve original parameter shapes
- `torch.compile`-friendly step with automatic eager fallback for init/preconditioning
Release benchmarks
HeavyBall 3.0.0 was benchmarked against HeavyBall 2.0.0 and torch.optim with
`benchmarks/bench_release_optimizers.py`; compiled AdamW step latency dropped
from 10.63 ms in HeavyBall 2.0.0 to 4.15 ms in HeavyBall 3.0.0, a 2.56x speedup.
Breaking changes
Class renames
Every `Foreach*` class is renamed to its short form. The old short-form aliases (which existed
in 2.x) keep working; only the `Foreach*` imports break. A before/after import sketch follows the table.
| 2.x name | 3.x name |
|---|---|
| `ForeachAdamW` | `AdamW` |
| `ForeachNAdam` | `NAdam` |
| `ForeachAdEMAMix` | `AdEMAMix` |
| `ForeachAdamC` | `AdamC` |
| `ForeachRMSprop` | `RMSprop` |
| `ForeachSFAdamW` | `SFAdamW` |
| `ForeachADOPT` | `ADOPT` |
| `ForeachMuon` | `Muon` |
| `ForeachLaProp` | `LaProp` |
| `ForeachSignLaProp` | `SignLaProp` |
| `ForeachSOAP` | `SOAP` |
| `ForeachSOAPNAdam` | `SOAPNAdam` |
| `ForeachSOAPAdEMAMix` | `SOAPAdEMAMix` |
| `ForeachSOLP` | `SOLP` |
| `ForeachPSGDKron` | `PSGDKron` |
| `ForeachPSGDLRA` | `PSGDLRA` |
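A minimal before/after import sketch (the model and learning rate are placeholders, not part of the release):

```python
import torch

# 2.x spelling, removed in 3.x:
# from heavyball import ForeachAdamW
# opt = ForeachAdamW(model.parameters(), lr=1e-3)

# 3.x spelling: the short name is the canonical class
from heavyball import AdamW

model = torch.nn.Linear(16, 16)
opt = AdamW(model.parameters(), lr=1e-3)
```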
Removed optimizer classes
These were thin subclasses that only set a class-level default. Use the parent class with the
corresponding constructor argument instead; a short conversion sketch follows the table.
| 2.x class | 3.x equivalent |
|---|---|
| `PaLMForeachSFAdamW` / `PaLMSFAdamW` | `SFAdamW(..., palm=True)` |
| `PaLMForeachSOAP` / `PaLMSOAP` / `PalmForEachSoap` | `SOAP(..., palm=True)` |
| `PrecondScheduleForeachSOAP` / `PrecondScheduleSOAP` | `SOAP(..., use_precond_schedule=True)` |
| `PrecondSchedulePaLMForeachSOAP` / `PrecondSchedulePaLMSOAP` | `SOAP(..., palm=True, use_precond_schedule=True)` |
| `ForeachPurePSGD` / `PurePSGD` | `PSGDKron(..., exp_avg_input=False)` |
| `ForeachCachedPSGDKron` / `CachedPSGDKron` | `PSGDKron(...)` (caching is now the default) |
| `ForeachDelayedPSGD` / `DelayedPSGD` | `PSGDKron(..., delayed=True)` |
| `ForeachCachedDelayedPSGDKron` / `CachedDelayedPSGDKron` | `PSGDKron(..., delayed=True)` |
| `ForeachCachedNewtonPSGD` / `NewtonPSGDKron` | `PSGDKron(..., hessian_approx=True)` |
| `NewtonHybrid2PSGDKron` | `PSGDKron(..., hessian_approx=True, hvp_interval=2)` |
| `ForeachDelayedPSGDLRA` / `DelayedPSGDLRA` | `PSGDLRA(..., delayed=True)` |
| `ForeachNewtonPSGDLRA` / `NewtonPSGDLRA` | `PSGDLRA(..., hessian_approx=True)` |
| `NewtonHybrid2PSGDLRA` | `PSGDLRA(..., hessian_approx=True, hvp_interval=2)` |
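For instance, two of the removed subclasses might be converted to the parent-plus-kwargs form as in the sketch below (placeholder model and learning rates):

```python
import torch
from heavyball import SOAP, PSGDKron

model = torch.nn.Linear(16, 16)

# 2.x: PrecondSchedulePaLMSOAP(model.parameters(), lr=3e-4)
soap_opt = SOAP(model.parameters(), lr=3e-4, palm=True, use_precond_schedule=True)

# 2.x: NewtonHybrid2PSGDKron(model.parameters(), lr=1e-3)
psgd_opt = PSGDKron(model.parameters(), lr=1e-3, hessian_approx=True, hvp_interval=2)
```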
Renamed parameters
| 2.x parameter | 3.x parameter | Notes |
|---|---|---|
| `foreach` | `multi_tensor` | Passing `foreach` emits a `FutureWarning` and remaps automatically |
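A minimal sketch of the rename (placeholder model and learning rate):

```python
import torch
from heavyball import AdamW

model = torch.nn.Linear(16, 16)

# 2.x spelling: still accepted in 3.x, but emits a FutureWarning and is remapped
opt_old = AdamW(model.parameters(), lr=1e-3, foreach=True)

# 3.x spelling
opt_new = AdamW(model.parameters(), lr=1e-3, multi_tensor=True)
```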
Removed parameters
These raise `TypeError` if passed. They were either unused or replaced by better defaults; a sketch for the `inverse_free` replacement follows the table.
| Parameter | Previously on | Notes |
|---|---|---|
| `stochastic_schedule` | SOAP, PSGDKron, PSGDLRA | Deterministic accumulation schedule is now the only mode |
| `normalize_grads` | SOAP variants | Was unused in the transform pipeline |
| `correct_bias` | SOAP variants | Was unused in the transform pipeline |
| `inverse_free` | PSGDKron | Use `quad_torch` or `PSGDPRO` for inverse-free PSGD |
| `adaptive` | PSGDKron | Removed |
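For example, code that previously passed `inverse_free=True` to `PSGDKron` could move to the new `PSGDPRO` optimizer; a minimal sketch (placeholder model and learning rate):

```python
import torch
from heavyball import PSGDPRO

model = torch.nn.Linear(16, 16)

# 2.x: PSGDKron(model.parameters(), lr=1e-3, inverse_free=True)  # now raises TypeError
opt = PSGDPRO(model.parameters(), lr=1e-3)  # inverse-free PSGD in 3.x
```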
Helper sampler kwargs
These compatibility kwargs were removed from `heavyball.helpers` samplers and now raise
`TypeError`.
| Class | Removed kwargs |
|---|---|
| `BoTorchSampler` | `constraints_func`, `consider_running_trials` |
| `HEBOSampler` | `constant_liar` |
| `ImplicitNaturalGradientSampler` | `lr`, `warn_independent_sampling` |
| `AutoSampler` | `constraints_func` |
Chainable API renames
| 2.x name | 3.x name |
|---|---|
| `Branch` | `Parallel` |
Behavioral changes
- ScheduleFree / MSAM `eval()`/`train()`: Now idempotent. Calling `eval()` twice no longer flips back to train mode. Both methods accept a `mode` argument matching `nn.Module.train(mode)` and return `self`.
- Gradient lifetime: `consume_grad=True` is available on all optimizers and clears `p.grad` during `step()` once the gradient has been consumed. Set `consume_grad=False` if your code reads gradients after stepping or relies on them remaining attached (see the sketch after this list).
- Sharded parameter shapes: Built-in optimizers now expose `orig_shapes` explicitly. Use `capture_param_shapes()` before wrapping parameters if your sharding backend hides original shapes.
- PSGD dampening: the `dampen_grad` default changed from `2**-13` to `1e-9`, and the dampening epsilon uses `torch.finfo(torch.float32).eps` regardless of input dtype. This improves preconditioner accuracy but may change convergence behavior.
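The sketch below illustrates the gradient-lifetime and idempotent mode switches; it assumes `SFAdamW` (a ScheduleFree optimizer) accepts `consume_grad` like the other built-ins, and the model and learning rate are placeholders:

```python
import torch
from heavyball import SFAdamW  # ScheduleFree AdamW

model = torch.nn.Linear(8, 8)
opt = SFAdamW(model.parameters(), lr=1e-3, consume_grad=False)  # keep p.grad after step()

opt.train()
model(torch.randn(4, 8)).sum().backward()
opt.step()
assert all(p.grad is not None for p in model.parameters())  # gradients stay attached

opt.eval()
opt.eval()  # idempotent in 3.x: stays in eval mode instead of flipping back to train
```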
Checkpoint migration
Use the migration CLI to convert 1.x or 2.x checkpoints:
`python scripts/migrate_optimizer_state.py <checkpoint.pt> <OptimizerClass>`
Old class names (including all aliases listed above) are resolved automatically.
The `foreach` → `multi_tensor` key rename in param groups is handled automatically.
Upgrade checklist
- Replace `from heavyball import Foreach*` imports with the short names (e.g., `ForeachAdamW` → `AdamW`)
- Replace `foreach=` with `multi_tensor=` in constructor calls
- Replace removed subclass instantiations with the parent class plus kwargs (see table above)
- Remove any `stochastic_schedule`, `normalize_grads`, `correct_bias`, `inverse_free`, or `adaptive` kwargs
- Replace `Branch(...)` with `Parallel(...)` in custom chainable code
- Migrate checkpoints: `python scripts/migrate_optimizer_state.py <ckpt> heavyball.<Optimizer>`
- If you relied on `eval(); eval()` toggling back to train mode, update your code
- If your training loop reads `p.grad` after `step()`, pass `consume_grad=False`
- Remove obsolete compatibility kwargs from `heavyball.helpers` samplers
Fix ECC
ECC, support latest torch, more tests
- ECC from https://arxiv.org/abs/2602.23349 was added for bf16+int8 states; see the example
- Replaced all `torch._foreach` calls with for loops, which trace differently in torch 2.10 and resolve torch-side bugs
- All tests/ are now included in CI
SCION + Split Optimizer
- SplitOptimizer, following Andreas Kirsch's research on continual learning (https://x.com/BlackHC/status/2001961535120568542)
- Approximate SCION
- Numpy 2.0.0 support
Configurable Division
By default, HeavyBall's division differs from the industry standard, potentially giving meaningfully different results for otherwise identical optimizer hyperparameters.
You can now set `heavyball.utils.default_division_backend` to one of the following (sketched in plain tensor form after the list):
- `eps_clamp`, HeavyBall's default: `x / y.clamp(min=eps)`
- `eps_add`, the standard used by PyTorch, Optax, and others: `x / (y + eps)`
- `atan2`, following Adam-Atan2: `atan2(x / scale, y) * scale`; may require `heavyball.utils.atan2_scale` to be adjusted to clamp to a different range of target values
- `nan_to_0`, resulting in `(x / y).nan_to_num(0, 0, 0)`
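In plain tensor terms, the four backends correspond roughly to the expressions below; this is an illustrative sketch, not HeavyBall's internal code:

```python
import torch

x, y = torch.randn(4), torch.rand(4)
eps, scale = 1e-8, 1.0

out_eps_clamp = x / y.clamp(min=eps)               # HeavyBall's default
out_eps_add   = x / (y + eps)                      # PyTorch / Optax convention
out_atan2     = torch.atan2(x / scale, y) * scale  # Adam-Atan2 style
out_nan_to_0  = (x / y).nan_to_num(0.0, 0.0, 0.0)  # NaN/±Inf from y == 0 mapped to 0
```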
NAdam + AdEMAMix
This release focuses on adding new optimizers:
- NAdam (following @tom-jod's research)
- AdEMAMix
- `SOAPNAdam` - SOAP with NAdam in the eigenbasis
- `SOAPAdEMAMix` - SOAP with AdEMAMix in the eigenbasis
Note that this changes the previous SOAP infrastructure. SOAP variants manually created for the previous version will not work out of the box, but can be trivially converted.
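Constructing the new SOAP variants should look like any other HeavyBall optimizer; a minimal sketch (placeholder model and learning rate):

```python
import torch
from heavyball import SOAPNAdam, SOAPAdEMAMix

model = torch.nn.Linear(16, 16)
opt = SOAPNAdam(model.parameters(), lr=3e-4)      # SOAP with NAdam in the eigenbasis
# opt = SOAPAdEMAMix(model.parameters(), lr=3e-4) # SOAP with AdEMAMix in the eigenbasis
```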
Bugfixes, Memory reduction, Save/Restore
- torch autocasts psgd's internal step from int64 to fp64, which caused a mismatch in states before/after loading
- an unbounded `lru_cache`, used to speed up parameter accesses, may have kept parameters around indefinitely
- with psgd, `caution=True`, and `foreach=False`, caution was only applied to the first parameter
- psgd quad with bf16 parameters tried to multiply a bf16 matrix with an fp32 matrix
- psgd's `preconditioner_update_probability` was ignored if set to 0, resulting in the default schedule being used
- not all optimizers were exposed in `__all__`
- psgd's scheduler was not stepped properly, causing the scheduler to remain at 100% precond update probability
- msign/thinky_polar_express (a zeroth-power backend) always returned fp32 tensors, where it should have adapted to the input dtype
- soap's initial eigh did not put the tensors back into their original dtype
- the built-in ema exited after the first empty parameter group, potentially skipping updates
- the built-in ema was updated once for each parameter group, causing different effective ema betas for different param group counts
- finite differences hvp did not divide by the epsilon scaling
- `pointwise_lr_adaptation` called `lr_adaptation` instead of `pointwise_lr_adaptation`
- fused_hook may have processed 1-tensor models incorrectly
Stable Muon
Stability, Stability, Stability (and bug-fixes)
- Higher numerical stability in Newton-Schulz orthogonalization (affects: Muon)
- Higher accuracy in SVD computation (affects: PSGD)
- Advanced checkpointing (affects: old checkpoints, SOAP and PSGD)
- Reworked chainable backend, allowing more freedom in function composition (affects: custom optimizers)
For the full release notes and migration instructions, see here
- Benchmark
- Optimizers
  - AdamC implemented (#65, by @Ryu1845 & #77, by @drexalt)
  - Fixed all clipping algorithms with non-default clipping thresholds, added clipping tests (#73 & #74, by @alexjwilliams)
  - More accurate NS5 iterations, backporting Muon research (#69, by @xTimeCrystal)
Fixed SOAP, HVP PSGD
Bugfixes:
- @francois-rozet fixed a severe convergence regression in SOAP. It's now faster and converges better than before (#42)
- ADOPT now correctly matches the paper, significantly improving its convergence
- FP64 storage and/or computation now works for more optimizers
Improvements:
- NewtonPSGD now supports exact HVP calculation instead of the previous approximate. (Handles BatchNorm better but doesn't support all architectures.)
"smart_one_diag"is a next-to-no-downsidesmemory_save_modefor PSGD. It reduces memory and compute cost compared tomemory_save_mode=Noneand improves convergence compared tomemory_save_mode="one_diag"*
*Instead of preconditioning all dimensions (memory_save_mode=None) or preconditioning all but the largest dimension (memory_save_mode="one_diag") we remove the largest dimension iff it's larger than the second largest. So, a Linear(128, 1024) will now create one 128x128 preconditioner (instead of 128x128 + 1024x1024, 8x as large as the parameters), while a Linear(128, 128) can still benefit from preconditioning both sides.
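A minimal sketch of opting into this mode, using the current class name (placeholder model and learning rate; only `memory_save_mode` is taken from the note above):

```python
import torch
from heavyball import PSGDKron

model = torch.nn.Linear(128, 1024)
# keeps one 128x128 preconditioner; the larger 1024-sized dimension is dropped
opt = PSGDKron(model.parameters(), lr=1e-3, memory_save_mode="smart_one_diag")
```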