Releases: google/flax
0.12.0
Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
nnx.Pytree, and therefore nnx.Module, are now stricter about attributes that contain Arrays and about changing the status of attributes. For example, the code below now fails:
```python
from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None  # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))  # ERROR
```
This happens for two reasons:
- JAX pytree structures that contain Arrays now have to be marked with `nnx.data`. Alternatively, if the container pytree is a `list` or a `dict`, you can use `nnx.List` or `nnx.Dict`, which additionally allow mixed "data" and "static" elements.
- Attributes will no longer automatically change their status; this now has to be done explicitly using `nnx.data` or `nnx.static`. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
To fix the above, create `layers` as an `nnx.List`, which is automatically recognized as data, and be explicit about `bias` being a data attribute on the first assignment by using `nnx.data`:
```python
class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))
```
For more information, check the Module & Pytree guide.
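For illustration, a minimal sketch of the explicit-status APIs described above; the `Bar` class and its attribute names are hypothetical:
```python
class Bar(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    # explicitly mark an Array-holding attribute as data
    self.scale = nnx.data(jnp.ones((3,)))
    # explicitly mark a plain Python attribute as static
    self.kind = nnx.static('bar')
    # nnx.List allows mixed "data" (Modules, Arrays) and "static" elements
    self.blocks = nnx.List([nnx.Linear(3, 3, rngs=rngs), 'block-tag'])
```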
Eager Sharding
Variables will now eagerly shard their values when `sharding_names` metadata is provided. A mesh is required; it can be provided either by passing a `mesh` metadata attribute or by setting the global mesh context via `jax.set_mesh`. This moves the sharding of a Variable to construction time:
```python
jax.config.update('jax_num_cpu_devices', 8)

mesh = jax.make_mesh((2, 4), ('data', 'model'))
with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))
  print(variable.value.sharding)
```
Eager sharding also occurs when using the `nnx.with_partitioning` initializer decorator, and it automatically extends to the Optimizer. This means that both the model and the optimizer are sharded at construction, without the need for the somewhat cumbersome `nnx.get_partition_spec` + `jax.lax.with_sharding_constraint` + `nnx.update` pattern:
```python
import optax

with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
  print(linear.kernel.value.sharding)
  print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
For projects that currently rely on other means for sharding, eager sharding can be turned off by passing `eager_sharding=False` to the Variable constructor, either directly or through initializer decorators like `nnx.with_partitioning`:
```python
linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
```
Eager sharding can also be turned off globally via the `flax_always_shard_variable` config flag or the `FLAX_ALWAYS_SHARD_VARIABLE` environment variable:
```python
import flax

flax.config.update('flax_always_shard_variable', False)
```
For more information, check out the Variable eager sharding FLIP.
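For comparison, a rough sketch of the older pattern referenced above, following the shape of the Flax GSPMD guide; the `Model` class here is hypothetical:
```python
@nnx.jit
def create_sharded_model():
  model = Model(rngs=nnx.Rngs(0))  # variables are unsharded at this point
  state = nnx.state(model)
  pspecs = nnx.get_partition_spec(state)  # extract PartitionSpecs from metadata
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)  # the model now holds sharded Arrays
  return model

with jax.set_mesh(mesh):
  model = create_sharded_model()
```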
In-Place Operators No Longer Allowed
In-place operators on Variables now raise an error. This is part of the push for Variables to be compatible with Tracer semantics:
```python
w = nnx.Variable(jnp.array(0))
w += 1  # ERROR
```
The fix is to simply operate on the `.value` property instead:
```python
w.value += 1
```
All Changes
- Doc fix: remove dead link to pre-Orbax checkpointing. by @copybara-service[bot] in #4914
- Fix typo in unflatten docs by @copybara-service[bot] in #4918
- fix RNN by @copybara-service[bot] in #4917
- Update optimizer.py to support masked variable from optax. by @ywrt in #4904
- Added missing functions to graph.rst by @vfdev-5 in #4922
- Update flax/docs_nnx/guides/performance.md and .ipynb by @hanrach9 in #4919
- Added preferred_element_type arg to nnx.Linear*, nnx.Conv*, nnx.Einsum by @vfdev-5 in #4920
- Update README badges and remove invalid ones by @IvyZX in #4905
- static + pytree guide by @cgarciae in #4897
- fix mypy by @copybara-service[bot] in #4931
- Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Ported nnx.PReLU from linen by @vfdev-5 in #4934
- Added nnx.scan docs and few minor docs fixes by @vfdev-5 in #4930
- add variables argument to nnx.clone by @cgarciae in #4945
- only copy dicts on State.getitem by @cgarciae in #4946
- always differentiate standalone Variables in nnx.grad by @cgarciae in #4947
- Implement instance norm in NNX by @mattbahr in #4939
- Automatically apply sharding constraints to sharded models by @IvyZX in #4844
- Add reference of flip doc to gspmd guide by @IvyZX in #4949
- Fixed nnx.is_data docstring rendering by @vfdev-5 in #4957
- expose pytree guide by @cgarciae in #4951
- fix toy examples by @cgarciae in #4952
- Explicitly cast attribute names to string before checking for private attributes. by @copybara-service[bot] in #4955
- add flax_hijax_variable flag by @cgarciae in #4953
- mark shard_map as implemented in transforms guide by @cgarciae in #4738
- improve Variable flatten by @cgarciae in #4954
- Minor typo fix in nnx.call docstring by @vfdev-5 in #4959
- allow split tuples in Rngs.fork by @cgarciae in #4958
- Fixed Gemma example using Gemma2 models by @vfdev-5 in #4830
- finish pytree guide by @cgarciae in #4929
- update bridge wrappers from maxtext by @cgarciae in #4937
- fix HashableMapping hash definition for mixed key types by @copybara-service[bot] in #4936
- Flax RNG guide for jax.jit: clarify rng outputs are shared but not inputs. by @copybara-service[bot] in #4956
- fix Variable pytree flatten by @copybara-service[bot] in #4962
- import PathParts from flax.typing by @cgarciae in #4966
- Correctly expose `flax.config.temp_flip_flag` by @IvyZX in #4969
- raise on Variable inplace operators by @cgarciae in #4967
- Copybara import of the project: by @copybara-service[bot] in #4976
- update to version 0.12.0 by @cgarciae in #4982
- Minor typo fixes in flax gspmd guide by @vfdev-5 in #4970
- ignore uv.lock by @copybara-service[bot] in #4974
- [nnx] preserve the function's type information in jit by @cgarciae in #4981
- add Variable.set_metadata by @cgarciae in #4968
- propagate eager sharding by @cgarciae in #4983
New Contributors
Full Changelog: v0.11.2...v0.12.0
0.11.2
What's Changed
nnx.merge no longer creates a copy of the Variables in the incoming states by default, meaning that the newly merged structure holds references to the incoming Variables. This enables new patterns; for example, it's now possible to create models with the same state but with different runtime behavior:
```python
model = SomeModel(...)
# create eval model
eval_model = nnx.merge(*nnx.split(model))  # same Variables, different structure
eval_model.eval()
```
`model` and `eval_model` share the same Variables and are therefore kept in sync, but they have different runtime behavior. This avoids having to constantly mutate a single model back and forth between runtime modes, which can be error prone and cause unwanted recompilation.
To keep the old behavior, use `nnx.merge(..., copy=True)`.
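A small sketch of the difference, using `nnx.Linear` for concreteness; the assertions spell out the intended sharing semantics:
```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

shared = nnx.merge(*nnx.split(model))  # holds references to model's Variables
assert shared.kernel is model.kernel

independent = nnx.merge(*nnx.split(model), copy=True)  # old behavior: copies
assert independent.kernel is not model.kernel
```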
PRs
- add Rngs random helpers by @cgarciae in #4876
- Fix re-export and docs for identity by @jlperla in #4850
- Fix ToLinen docstring return description by @mohsinm-dev in #4852
- Update doc build instructions and clean up unused packages by @IvyZX in #4885
- Improve docs related with dataclasses by @IvyZX in #4884
- Fix broken contributing documentation link by @mohsinm-dev in #4855
- Internal change by @copybara-service[bot] in #4886
- Fix string key preservation in replace_by_pure_dict by @mohsinm-dev in #4860
- Remove the need for Conv and ConvTranspose to know the precise batch size. by @copybara-service[bot] in #4877
- call jax's source_info_util.register_exclusion in flax's traceback_util.register_exclusion by @copybara-service[bot] in #4887
- Update typo in nnx.Optimizer by @codinfox in #4880
- Exposed split_rngs docstring in the docs_nnx by @vfdev-5 in #4846
- Pin sentencepiece version to 0.2.0 to fix head by @IvyZX in #4892
- Relax duplicate check to exclude non-string values such as PartitionSpec.UNCONSTRAINED, since those can be repeated. by @copybara-service[bot] in #4881
- add find_duplicates by @cgarciae in #4894
- Sharding API improvements (non breaking) by @IvyZX in #4893
- document jax.random shorthand methods by @cgarciae in #4899
- Optimiser was already instantiated using the model - 05_vae.py by @nenuadrian in #4857
- revert is_leaf logic in _check_carry_same_references by @copybara-service[bot] in #4903
- Doc fix: remove outdated advice on flax v0.6.10; it was released two years ago. by @copybara-service[bot] in #4910
- Fix bug when raising ScopeParamNotFoundError. by @copybara-service[bot] in #4898
- fix mypy on main by @cgarciae in #4909
- merge no copy Variables by @cgarciae in #4912
- update version to 0.11.2 by @copybara-service[bot] in #4915
New Contributors
- @mohsinm-dev made their first contribution in #4852
- @codinfox made their first contribution in #4880
- @nenuadrian made their first contribution in #4857
Full Changelog: v0.11.1...v0.11.2
v0.11.1
What's Changed
- Make `Sequential()` be identity by @SobhanMP in #4796
- Add a JAX/Flax key concepts doc by @IvyZX in #4795
- miscellaneous improvements by @cgarciae in #4859
- Replace `jax.sharding.use_mesh` with `jax.set_mesh`. `jax.set_mesh` can act as a global setter or a context manager. by @copybara-service[bot] in #4862
- Pytree and ArrayRef refactor by @cgarciae in #4863
- Add old property attributes for object->pytree rename. by @copybara-service[bot] in #4864
- Add BatchNorm layers to CNN in MNIST tutorial for improved training stability by @sanepunk in #4773
- Description by @copybara-service[bot] in #4866
- update and pop for dict by @cgarciae in #4869
- simplify nnx_basics by @cgarciae in #4868
- updates to version 0.11.1 by @cgarciae in #4878
New Contributors
Full Changelog: v0.11.0...v0.11.1
v0.11.0
v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! However, some breaking changes were necessary to align with the JAX way of doing things. Most code should remain intact; the following changes deviate from previous behavior:
- `Rngs` in standard layers: all standard layers no longer hold a shared reference to the `rngs` object given in the constructor; instead, they now keep a `fork`-ed copy of the `Rngs` or `RngStream` objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer Updates: the Optimizer abstraction no longer holds a reference to the `model`, to avoid reference sharing; instead, the `model` must be provided as the first argument to `update` (see the sketch after this list).
- Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of `split` and `merge` when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects operating on Pytrees containing NNX Objects with `jax.tree.*` APIs.
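A minimal sketch of the new conventions, assuming an `nnx.Linear` model and an illustrative loss function:
```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

def loss_fn(model):
  y = model(jnp.ones((1, 2)))
  return (y ** 2).mean()

grads = nnx.grad(loss_fn)(model)
optimizer.update(model, grads)  # the model is now passed explicitly to update

# Modules are pytrees, so raw jax.tree.* APIs work on them directly:
print(len(jax.tree.leaves(model)))
```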
Check out the full NNX 0.10 to NNX 0.11 migration guide.
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- [nnx] mutable array p3 by @cgarciae in #4755
- [nnx] allow method calls in ToLinen by @cgarciae in #4808
- Internal change by @copybara-service[bot] in #4807
- Preserve sharding information in axes_scan by @copybara-service[bot] in #4806
- Deduplicate contributing and philosophy and move to main site by @IvyZX in #4809
- Fixed nnx.remat docstring rendering by @vfdev-5 in #4790
- Added a note to gemma guide about model's license consent on kaggle by @vfdev-5 in #4776
- [nnx] ToLinen add abtract_init flag by @cgarciae in #4813
- Modify NNX to use id(variable) instead of nnx.Variables as dictionary by @divyashreepathihalli in #4814
- Allow using LazyRngs for flax init/apply. by @copybara-service[bot] in #4818
- [nnx] remove VariableState by @cgarciae in #4800
- Fix failing CI jobs: trailing whitespace, deprecated `.type` usage by @vfdev-5 in #4823
- [nnx] fix Rngs dtype check by @cgarciae in #4820
- refactor: move usages of `.value` to `[...]` in modules_test.py by @lukeyeh in #4815
- Added training script for Gemma model by @vfdev-5 in #4822
- [nnx] add flax_pytree_module flag by @cgarciae in #4811
- create ModelAndOptimizer symbol by @copybara-service[bot] in #4849
- [nnx] remove Optimizer.model attribute by @cgarciae in #4842
- [nnx] add mutable array support in update by @cgarciae in #4851
- Migrate `transforms_test.py` from `.value` to `[...]` by @lukeyeh in #4841
- 0.11.0 migration guide by @cgarciae in #4854
New Contributors
- @divyashreepathihalli made their first contribution in #4814
- @lukeyeh made their first contribution in #4815
Full Changelog: v0.10.7...v0.11.0
0.10.7
What's Changed
- Added identity export from JAX by @jlperla in #4652
- Fixes a bug in type annotations for scope.param (unbox=True should accept Callable[..., T | AxisMetadata[T]] and return T, while unbox=False should always return the same thing as the callable returns). by @copybara-service in #4727
- fix merge by @copybara-service in #4731
- [nnx] make Variable a pytree by @cgarciae in #4728
- [nnx] add JitWrapped API by @cgarciae in #4699
- Update JAX nightly index usage by @copybara-service in #4733
- [nnx] mutable array p1 by @cgarciae in #4715
- add dataclass by @copybara-service in #4739
- [flax] unconditionally register nnx.Variable as a pytree by @copybara-service in #4748
- Updated version of pre-commit-hooks in .pre-commit-config.yaml by @vfdev-5 in #4746
- Fixed docstring visibility for nnx.eval_shape by @vfdev-5 in #4747
- Added keep_rngs arg to MHA to optionally store rngs by @vfdev-5 in #4749
- MultiHeadAttention only keeps rngs if dropout_rate is positive by @copybara-service in #4750
- [nnx] mutable array p2 by @cgarciae in #4741
- Add in_kv_features argument to nnx.MultiHeadAttention, addressing #4756. by @copybara-service in #4757
- Fix broken link for Transforms guide by @nireekshak in #4763
- Minor improvements of lm1b_nnx example by @vfdev-5 in #4745
- Fix head CI tests by @IvyZX in #4764
- Fix typos by @nireekshak in #4725
- Check for leaves of type variablelib.Variable when getting sharding specs. by @copybara-service in #4769
- Fixes #1925: non-str dict keys not supported in module state by @muhrin in #4563
- Modified the Functional API link by @nireekshak in #4767
- Fix hardcoded link to filter guide in docs by @hamogu in #4768
- Fix bad doc links by @IvyZX in #4770
- revise axes_scan to flatten argument pytrees only once by @copybara-service in #4772
- Simplify ToNNX access of Linen module methods by @IvyZX in #4766
- Use `.input_formats` and `.output_formats` in place of `.input_layouts` and `.output_layouts` respectively. by @copybara-service in #4784
- Exposed OptState in nnx module by @vfdev-5 in #4788
- Fixes colab link for nnx docs by @vfdev-5 in #4775
- Internal changes by @copybara-service in #4786
- Fix typo in Flax `nnx_basics` doc. by @copybara-service in #4781
- update version to 0.10.7 by @cgarciae in #4798
New Contributors
- @nireekshak made their first contribution in #4763
- @muhrin made their first contribution in #4563
- @hamogu made their first contribution in #4768
Full Changelog: v0.10.6...v0.10.7
0.10.6
What's Changed
- Sow top activations based on absolute value. by @copybara-service in #4670
- Add support for layer-specific rope scale factors. by @copybara-service in #4672
- Automatic model selection for Gemma 3 models. by @copybara-service in #4671
- Make LoRA's dtype arg useful by @IvyZX in #4681
- [NVIDIA] Support FP8 Einsum Op by @kaixih in #4686
- [nnx] remove deprecated APIs by @cgarciae in #4627
- Add `attention_bias` parameter to `MultiHeadDotProductAttention`. by @copybara-service in #4694
- Unit tests for `attention_bias` parameter to `MultiHeadDotProductAttention`. Add parameter to all overloads to make pytype happy. by @copybara-service in #4702
- Rollback of attention_bias parameter, because the change overrides the attention bias for injected attention functions. by @copybara-service in #4703
- Add custom einsum op to Einsum() by @IvyZX in #4705
- [nnx] refactor GraphDef by @cgarciae in #4630
- Make fully replicated array before saving checkpoints for examples that use pmap. by @copybara-service in #4707
- Fix CI by @cgarciae in #4716
- remove "nnx" collection in ToLinen by @copybara-service in #4708
- [nnx] flaxlib types by @cgarciae in #4639
- v0.10.6 by @cgarciae in #4724
Full Changelog: v0.10.5...v0.10.6
0.10.5
What's Changed
- [nnx] fix tabulate by @cgarciae in #4580
- Refactor bridge.Module tests from `wrappers_test.py` to another file. by @copybara-service in #4581
- Avoid calls to jnp.shape for non-array inputs. by @jakevdp in #4592
- remove Embed nan casting by @cgarciae in #4600
- Add QK Norm. by @copybara-service in #4594
- Util to let bridge module work with NNX submodules by @IvyZX in #4584
- Add configurable Query Pre Attention scalar. by @copybara-service in #4595
- Make RoPE Base Frequency configurable. by @copybara-service in #4596
- [nnx] pytrees are graph nodes by @cgarciae in #4547
- Add option to load checkpoints with transposed Gating Einsum. by @copybara-service in #4597
- add top_p sampling in gemma example by @copybara-service in #4591
- Fix position and name of Post Attention Norm. by @copybara-service in #4598
- Add Sow Config to from_params constructor. by @copybara-service in #4599
- bridge module with linen submodule by @IvyZX in #4604
- Dramatically speed up sampling compilation time by @copybara-service in #4574
- [nnx] improve grad docs by @cgarciae in #4588
- [nnx] add support for standalone Variables by @cgarciae in #4606
- add promote_dtype as a config option for multiple layers by @cgarciae in #4613
- Copybara import of the project: by @copybara-service in #4616
- Fixed typo in `beam_search` loop. by @copybara-service in #4615
- support swap model params in gemma sampler by @copybara-service in #4614
- Allow bridge module to have 'name' field by @IvyZX in #4619
- fix performance guide by @cgarciae in #4621
- Copybara import of the project: by @copybara-service in #4618
- Add REFLECT padding to convolution layer by @sarlinpe in #4553
- fix trace-level detection by @cgarciae in #4527
- Add attribute path customization to bridge modules by @IvyZX in #4624
- add reprlib max depth flag by @cgarciae in #4632
- Allow custom axis metadata annotation during transforms by @IvyZX in #4637
- [bridge module] Allow name arg to represent actual submodule path by @IvyZX in #4634
- [nnx] improve Variable proxy for binary operations by @cgarciae in #4641
- Fix module stack typing annotation. by @copybara-service in #4633
- Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad. by @copybara-service in #4617
- discord release webhook by @cgarciae in #4646
- [nnx] support Array leaves in graph nodes by @cgarciae in #4612
- Roll up package jax version and uv.lock by @IvyZX in #4648
- Use jax.nn.dot_product_attention when possible by @IvyZX in #4649
- Fix flaky vmap test tolerance. by @copybara-service in #4653
- Test runner ubuntu upgrade 24.04 by @IvyZX in #4659
- Fix lazy_init typo by @IvyZX in #4657
- deflake a test by @copybara-service in #4663
- v0.10.5 by @cgarciae in #4656
New Contributors
Full Changelog: v0.10.4...v0.10.5
Release 0.10.4
What's Changed
- update pypi publish by @cgarciae in #4538
- [nnx] register_variable_name refactor by @copybara-service in #4540
- added support to the accuracy metric for binary classification by @mattbahr in #4536
- [nnx] bridge Module by @cgarciae in #4542
- [nnx] copy _var_metadata by @copybara-service in #4548
- [bridge] fix unbox logic by @copybara-service in #4551
- Add `is_initializing` API by @copybara-service in #4550
- [nnx] Add specific model typing for nnx.Optimizer by @marcelroed in #4470
- Add linen metadata conversion to linx by @IvyZX in #4552
- [bridge] improve Module context by @cgarciae in #4554
- Raise error if user uses 'name' in bridge module setup by @IvyZX in #4555
- Add deprecation warning to all `nnx.State` methods by @IvyZX in #4561
- [nnx] add shard_map by @cgarciae in #4490
- Fix CI breakages from newest jax by @IvyZX in #4576
- [bridge] Set _initializing correctly and avoid return RNG states by @copybara-service in #4569
- v0.10.4 by @cgarciae in #4579
New Contributors
- @mattbahr made their first contribution in #4536
- @marcelroed made their first contribution in #4470
Full Changelog: v0.10.3...v0.10.4
Version 0.10.3
What's Changed
- Fix fori_loop and while_loop on multiple modules by @IvyZX in #4390
- Upgrade Flax readme to NNX by @8bitmp3 in #4386
- [nnx] add performance guide notebook by @cgarciae in #4384
- Automated Code Change by @copybara-service in #4393
- [nnx] optimize NodeDef.attributes by @cgarciae in #4399
- Fixed the broken link in haiku_to_flax.rst file by @tilakrayal in #4402
- [nnx] optimize Variable by @cgarciae in #4400
- Update Flax NNX Randomness by @8bitmp3 in #4279
- Remove the repeated methods in flax.nnx.Module documentation by @rajasekharporeddy in #4416
- Fixed the broken link in linen_to_nnx.rst by @tilakrayal in #4415
- [nnx] add FlatState by @cgarciae in #4410
- Update `async_checkpointer.py` reference by @emmanuel-ferdman in #4385
- Fix multiple links to Orbax documentation by @Matt-Hurd in #4364
- Update links in `Why Flax NNX` documentation by @rajasekharporeddy in #4425
- [nnx] fix transforms guide by @cgarciae in #4421
- Add benchmark on state traversal, and a readme by @IvyZX in #4428
- Update Flax NNX performance guide by @8bitmp3 in #4401
- Create sharding via Partitioned.get_sharding() by @copybara-service in #4427
- Update Flax NNX vs JAX Transformations guide by @8bitmp3 in #4286
- Upgrade Flax NNX Gemma Sampling Inference doc by @8bitmp3 in #4325
- Update NNX `merge` docs in graph.py by @8bitmp3 in #4411
- Fix main and add nnx.fori_loop test by @cgarciae in #4472
- Upgrade Flax NNX Filters doc by @8bitmp3 in #4199
- Makes flax Modules more compatible with IPython auto-reload. by @copybara-service in #4420
- [nnx] RNN: add broadcast_rngs and state_axes APIs by @cgarciae in #4407
- Allow `nnx.bridge.variables.nnx_attrs_to_linen_vars` to take `nnx.VariableState` as argument. by @copybara-service in #4473
- [nnx] add state summaries for print and display by @cgarciae in #4438
- Copybara import of the project: by @copybara-service in #4475
- [nnx] add state summaries for print and display by @copybara-service in #4477
- CI: add scheduled test against nightly JAX releases by @jakevdp in #4478
- CI: pin actions to specific commits by @jakevdp in #4479
- [nnx] fix MultiMetric typing by @cgarciae in #4485
- [nnx] fix ToNNX linen_attributes update by @cgarciae in #4486
- Remove usages of orbax_utils.save_args_from_target, as this function does nothing (it used to control a checkpointing behavior that has since been optimized away). by @copybara-service in #4482
- [nnx] improve Module docs by @cgarciae in #4499
- Update einsum layer for Gemma example by @copybara-service in #4498
- [nnx] fix fiddle by @cgarciae in #4500
- Don't create params in normalization layers instead of creating None-value params. by @copybara-service in #4501
- Rename variable string mapping utils and move them to variableslib by @IvyZX in #4503
- fix LoRA initialization error in nnx layer by @copybara-service in #4502
- Remove all `Param(None)` lines from NNX by @IvyZX in #4504
- make gemma FFW LoRA friendly by @copybara-service in #4510
- Add `nnx.Module.perturb` by @IvyZX in #4515
- [nnx] add tabulate by @cgarciae in #4493
- batch_norm.rst: == should be = by @cool-RR in #4524
- v0.10.3 by @cgarciae in #4525
New Contributors
- @tilakrayal made their first contribution in #4402
- @emmanuel-ferdman made their first contribution in #4385
- @Matt-Hurd made their first contribution in #4364
Full Changelog: v0.10.2...v0.10.3
Version 0.10.2
What's Changed
- Add `nnx.fori_loop` by @IvyZX in #4353
- Linesearch (and lbfgs) support by @jlperla in #4351
- Upgrade Flax NNX Haiku Linen migration doc by @8bitmp3 in #4200
- Fix PRNG handling in `nn.jit` under `nn.scan`. by @copybara-service in #4359
- support passing arguments directly to the struct.dataclass decorator by @copybara-service in #4275
- Avoid assert_array_equal for PRNG keys. by @copybara-service in #4363
- [nnx] support pure dicts by @cgarciae in #4352
- [nnx] add data parallel toy example by @cgarciae in #4354
- Add logical axis global context support for NNX by @IvyZX in #4350
- [nnx] fix ToLinen kwargs by @copybara-service in #4270
- [nnx] use HashableMapping instead of FrozenDict by @cgarciae in #4376
- [nnx] fix while_loop/fori_loop bug when sharing references by @cgarciae in #4379
- Add `flax.nnx.eval_shape` docstring by @8bitmp3 in #4374
- Setup the flaxlib in C++, using Meson and Nanobind. by @copybara-service in #4380
- Add `flax.nnx.remat` docstring by @8bitmp3 in #4373
- [nnx] add checkify by @cgarciae in #4381
- Lint flax.nnx.while_loop docstring by @8bitmp3 in #4371
- Lint flax.nnx.fori_loop docstring by @8bitmp3 in #4370
- [nnx] add some optimizations to graph.py by @cgarciae in #4377
- update version to 0.10.2 by @cgarciae in #4387
New Contributors
Full Changelog: v0.10.1...v0.10.2