
Releases: google/flax

0.12.0

25 Sep 23:58


Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.

Breaking Changes

Pytree Strict Attributes

nnx.Pytree, and therefore nnx.Module, is now stricter about attributes that contain Arrays and about changing the status of attributes. For example, the code below now fails:

from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,))) # ERROR

This happens for two reasons:

  1. JAX pytree structures that contain Arrays now have to be marked with nnx.data. Alternatively, if the container pytree is a list or a dict, you can use nnx.List or nnx.Dict, which additionally allow mixed "data" and "static" elements.
  2. Attributes no longer automatically change their status; this must now be done explicitly using nnx.data or nnx.static. Additionally, assigning Arrays (or structures containing Arrays) to static attributes is now an error, as they will not automatically change to data (see the sketch after this list).
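
As a minimal sketch of rule 2 (the class and attribute names here are illustrative, based on the description above, not taken from the guide):

class Scale(nnx.Module):
  def __init__(self):
    self.scale = nnx.data(None)    # explicitly declared as data on first assignment
    self.scale = jnp.ones((3,))    # OK: the attribute is already data
    self.kind = 'affine'           # plain Python value, static by default
    # self.kind = jnp.ones((3,))   # ERROR: Arrays cannot be assigned to a static attribute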

To fix the original example, create layers as an nnx.List, which is automatically recognized as data, and explicitly mark bias as a data attribute on its first assignment using nnx.data:

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))
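
As a quick usage check (a minimal sketch using the class above):

model = Foo(use_bias=True, rngs=nnx.Rngs(0))
assert isinstance(model.layers, nnx.List)
assert isinstance(model.bias, nnx.Param)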

For more information, check out the Module & Pytree guide.

Eager Sharding

Variables now eagerly shard their values when sharding_names metadata is provided. A mesh is required; it can be provided either as a mesh metadata attribute or by setting the global mesh context via jax.set_mesh. This moves the sharding of a Variable to construction time:

jax.config.update('jax_num_cpu_devices', 8)
mesh = jax.make_mesh((2, 4), ('data', 'model'))

with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))
  
print(variable.value.sharding)
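
If the snippet runs on the 8-device CPU mesh configured above, the printed sharding should be a NamedSharding whose PartitionSpec matches (None, 'model'), confirming the value was placed at construction time.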

Eager sharding will also occur when using the nnx.with_partitioning initializer decorator and will automatically extend to the Optimizer. This means that both model and optimizer will be sharded at construction without the need for the somewhat cumbersome nnx.get_partition_spec + jax.lax.with_sharding_constraint + nnx.update pattern:

import optax

with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
  
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)

For projects that currently rely on other means for sharding, eager sharding can be turned off by passing eager_sharding=False to the Variable constructor, either directly or through initializer decorators like nnx.with_partitioning:

linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
  
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)

Eager sharding can also be turned off globally via the flax_always_shard_variable config flag or the FLAX_ALWAYS_SHARD_VARIABLE environment variable:

import flax
flax.config.update('flax_always_shard_variable', False)

For more information, check out the Variable eager sharding FLIP.

In-Place Operators No Longer Allowed

In-place operators on Variables now raise an error. This is part of the push to make Variables compatible with Tracer semantics:

w = nnx.Variable(jnp.array(0))
w += 1  # ERROR

The fix is to simply operate on the .value property instead:

w.value += 1

All Changes

New Contributors

Full Changelog: v0.11.2...v0.12.0

0.11.2

28 Aug 17:55


What's Changed

By default, nnx.merge no longer creates a copy of the Variables in the incoming states, meaning that the merged structure holds references to the incoming Variables. This enables new patterns; for example, it is now possible to create models that share the same state but have different runtime behavior:

model = SomeModel(...)
# create eval model
eval_model = nnx.merge(*nnx.split(model))  # same Variables, different structure
eval_model.eval()

model and eval_model share the same Variables and are therefore kept in sync, but they have different runtime behavior. This avoids constantly mutating a single model back and forth between runtime modes, which can be error-prone and cause unwanted recompilation.

To keep the old behavior use nnx.merge(..., copy=True).
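
As a minimal runnable sketch of the shared-reference behavior (using nnx.Linear as a stand-in for SomeModel above):

from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
clone = nnx.merge(*nnx.split(model))
assert clone.kernel is model.kernel  # the same Variable objects, so state stays in sync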

PRs

New Contributors

Full Changelog: v0.11.1...v0.11.2

v0.11.1

08 Aug 21:25


What's Changed

New Contributors

Full Changelog: v0.11.0...v0.11.1

v0.11.0

29 Jul 21:04


v0.11.0 - Pytrees, MutableArrays, and more!

This version of Flax introduces changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! Some breaking changes were necessary to align with the JAX way of doing things. Most code should remain intact; however, the following changes deviate from the previous behavior:

  • Rngs in standard layers: standard layers no longer hold a shared reference to the rngs object given in the constructor; instead, they keep a forked copy of the Rngs or RngStream objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
  • Optimizer Updates: the Optimizer abstraction no longer holds a reference to the model, to avoid reference sharing; instead, the model must be provided as the first argument to update (see the sketch after this list).
  • Modules as Pytrees: Modules are now pytrees! This avoids unnecessary use of split and merge when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects operating on pytrees containing NNX Objects with jax.tree.* APIs.
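
As a hedged sketch of the new Optimizer calling convention and the pytree behavior (the model, loss, and data here are illustrative, not from the migration guide):

import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

def loss_fn(model):
  return jnp.mean(model(jnp.ones((1, 2))) ** 2)

grads = nnx.grad(loss_fn)(model)
optimizer.update(model, grads)  # the model is now passed explicitly to update

# Modules are now pytrees, so jax.tree.* APIs can traverse them directly:
print(len(jax.tree.leaves(model)))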

Check out the full NNX 0.10 to NNX 0.11 migration guide.

In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!

What's Changed

New Contributors

Full Changelog: v0.10.7...v0.11.0

0.10.7

02 Jul 06:09


What's Changed

New Contributors

Full Changelog: v0.10.6...v0.10.7

0.10.6

23 Apr 20:26


What's Changed

Full Changelog: v0.10.5...v0.10.6

0.10.5

31 Mar 15:17


What's Changed

New Contributors

Full Changelog: v0.10.4...v0.10.5

0.10.4

27 Feb 00:10


What's Changed

New Contributors

Full Changelog: v0.10.3...v0.10.4

0.10.3

10 Feb 17:33


What's Changed

New Contributors

Full Changelog: v0.10.2...v0.10.3

0.10.2

19 Nov 19:58


What's Changed

New Contributors

Full Changelog: v0.10.1...v0.10.2