This package provides some useful functionality that is missing in base JAX.
Major features include:

- `vmap`: a drop-in replacement for `jax.vmap` with static-arg/kwarg support and per-kwarg axis control.
- `bounded_while_loop`: a reverse-mode-friendly, bounded `while_loop` implemented via `lax.scan`.
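For comparison, the two `vmap` features can be roughly approximated in plain JAX: static values folded in with `functools.partial`, and per-kwarg axes obtained by routing keyword arguments through `jax.vmap` as a single dict argument so each key gets its own `in_axes` entry. The helper `vmap_kw` and the names `vmul_plain`/`vh_plain` below are hypothetical, written only for illustration; this is a sketch under those assumptions, not how jaxmore is implemented.

```python
import functools

import jax
import jax.numpy as jnp


def vmap_kw(fun, in_axes=0, kw_axes=None):
    """Hypothetical helper: like jax.vmap, but with a per-kwarg axis (or None).

    Keyword arguments are passed to jax.vmap as one dict argument, so each key
    can get its own entry in in_axes. Kwargs missing from kw_axes default to
    axis 0, matching jax.vmap's behavior for keyword arguments.
    """
    kw_axes = {} if kw_axes is None else kw_axes

    def call(args, kwargs):
        return fun(*args, **kwargs)

    def wrapped(*args, **kwargs):
        if in_axes is None or isinstance(in_axes, int):
            arg_spec = (in_axes,) * len(args)
        else:
            arg_spec = tuple(in_axes)
        kw_spec = {k: kw_axes.get(k, 0) for k in kwargs}
        return jax.vmap(call, in_axes=(arg_spec, kw_spec))(args, kwargs)

    return wrapped


def mul(factor, x, *, offset):
    return factor * x + offset


def h(x, *, a, b):
    return x * a + b


# Static values: fold them in with functools.partial so only x is mapped.
vmul_plain = jax.vmap(functools.partial(mul, 3.0, offset=1.0))
print(vmul_plain(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]

# Per-kwarg axes: map 'a' along axis 0, broadcast 'b'.
vh_plain = vmap_kw(h, kw_axes={"a": 0, "b": None})
print(vh_plain(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))  # [11. 12. 13.]
```

The workaround works but has to be rebuilt for every function signature, which is the boilerplate a dedicated wrapper avoids.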
Install with pip:

```bash
pip install jaxmore
```

`jaxmore.vmap` is a drop-in replacement for `jax.vmap`. By default it behaves identically:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

# As with jax.vmap, keyword arguments are mapped over their leading axis.
vf = vmap(f)
vf(jnp.arange(3.0), scale=jnp.ones(3))  # Array([0., 1., 2.], dtype=float32)
```

Static args & kwargs: bake constants into a closure so they never cross the `jax.jit` boundary, reducing dispatch overhead:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

# factor=3.0 and offset=1.0 are folded in; only x is mapped.
vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0})
print(vmul(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]
```

Per-kwarg axis control: map, broadcast, or ignore individual keyword arguments independently (not possible with `jax.vmap`, which always maps keyword arguments over their leading axis):
```python
import jax.numpy as jnp
from jaxmore import vmap

def h(x, *, a, b):
    return x * a + b

# 'a' is mapped along axis 0, 'b' is broadcast (not mapped)
vh = vmap(h, in_kw={"a": 0, "b": None})
print(vh(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))
# [11. 12. 13.]
```

Broadcast a kwarg while mapping positional args:
```python
import jax.numpy as jnp
from jaxmore import vmap

def f(x, *, scale):
    return x * scale

# default_kw_axis=None: scale is broadcast, not mapped
vf = vmap(f, in_kw=True, default_kw_axis=None)
print(vf(jnp.arange(3.0), scale=2.0))  # [0. 2. 4.]
```

Optional JIT: JIT-compile the static-folded function before vmapping:
```python
import jax.numpy as jnp
from jaxmore import vmap

def mul(factor, x, *, offset):
    return factor * x + offset

vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0}, jit=True)
print(vmul(jnp.arange(4.0)))  # [ 1.  4.  7. 10.]
```

`jaxmore.bounded_while_loop` takes the same `cond_fn`/`body_fn` pair as `lax.while_loop`, plus a `max_steps` cap on the number of iterations. Simple loop over a scalar:
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(x):
    return x < 5

def body_fn(x):
    return x + 1

result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # 5
```

PyTree carry (tuple):
```python
import jax.numpy as jnp
from jaxmore import bounded_while_loop

def cond_fn(state):
    x, _ = state
    return x < 3

def body_fn(state):
    x, y = state
    return x + 1, y * 2

x, y = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(x, y)  # 3 8
```
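The feature list describes `bounded_while_loop` as a reverse-mode-friendly loop implemented via `lax.scan`. The sketch below shows one way such a loop can be built: a fixed-length `lax.scan` whose step applies `body_fn` through `lax.cond` only while `cond_fn` holds. The names `scan_bounded_while_loop` and `doubled` are hypothetical, and this is not necessarily how jaxmore implements it.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_bounded_while_loop(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical sketch: a while loop capped at max_steps, built on lax.scan.

    The scan always runs max_steps iterations. Once cond_fn is False the carry
    passes through unchanged, so the result matches lax.while_loop whenever the
    loop would finish within max_steps iterations. Because the trip count is
    fixed, reverse-mode differentiation (jax.grad) works, unlike lax.while_loop.
    """

    def step(val, _):
        # Run body_fn only while the condition still holds; otherwise keep val.
        new_val = lax.cond(cond_fn(val), body_fn, lambda v: v, val)
        return new_val, None

    final_val, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final_val


# Counts up to 5, as in the scalar example.
print(scan_bounded_while_loop(lambda x: x < 5, lambda x: x + 1, jnp.asarray(0), max_steps=10))  # 5


def doubled(x0):
    # Double x0 three times, tracking the iteration count in the carry.
    _, x = scan_bounded_while_loop(
        lambda s: s[0] < 3, lambda s: (s[0] + 1, s[1] * 2.0), (0, x0), max_steps=5
    )
    return x


print(jax.grad(doubled)(1.5))  # 8.0
```

The price of differentiability is that every call pays for `max_steps` scan iterations even when `cond_fn` turns False early, reverse mode stores residuals for all of them, and under `vmap` with a batched predicate `lax.cond` lowers to a select that evaluates `body_fn` on every step.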