
GalacticDynamics/jaxmore


jaxmore

There's more to JAX.


This package provides useful functionality that is missing from base JAX. Major features include:

  • vmap — a drop-in replacement for jax.vmap with static-arg/kwarg support and per-kwarg axis control.
  • bounded_while_loop — a bounded while_loop implemented via lax.scan, so that, unlike jax.lax.while_loop, it supports reverse-mode differentiation.

Installation

pip install jaxmore

Examples

vmap — static arguments and per-kwarg axis mapping

jaxmore.vmap is a drop-in replacement for jax.vmap. By default it behaves identically:

import jax.numpy as jnp
from jaxmore import vmap


def f(x, *, scale):
    return x * scale


vf = vmap(f)
vf(jnp.arange(3.0), scale=jnp.ones(3))  # Array([0., 1., 2.], dtype=float32)
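Because the default behavior is meant to match jax.vmap, the call above can be cross-checked against jax.vmap itself. The sketch below uses only core JAX. Note that jax.vmap maps keyword arguments over their leading axis, which is why scale is passed as a length-3 array here:

```python
import jax
import jax.numpy as jnp


def f(x, *, scale):
    return x * scale


# Plain jax.vmap: keyword arguments are mapped over axis 0,
# so `scale` must share the leading dimension of `x`.
vf_ref = jax.vmap(f)
out = vf_ref(jnp.arange(3.0), scale=jnp.ones(3))
print(out)
```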

Static args & kwargs — bake constants into a closure so they never cross the jax.jit boundary, reducing dispatch overhead:

import jax.numpy as jnp
from jaxmore import vmap


def mul(factor, x, *, offset):
    return factor * x + offset


vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0})
print(vmul(jnp.arange(4.0)))  # Array([ 1.,  4.,  7., 10.], dtype=float32)
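Conceptually, folding static values into a closure resembles wrapping the function with functools.partial before vectorizing it. The following is a minimal core-JAX sketch of that idea. It is an illustration only and is not claimed to be how jaxmore implements static_args/static_kw:

```python
import functools

import jax
import jax.numpy as jnp


def mul(factor, x, *, offset):
    return factor * x + offset


# Bake `factor` and `offset` into a closure; jax.vmap only ever sees `x`,
# so the constants are neither traced nor mapped.
folded = functools.partial(mul, 3.0, offset=1.0)
vmul_ref = jax.vmap(folded)
out = vmul_ref(jnp.arange(4.0))
print(out)
```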

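Per-kwarg axis control — map, broadcast, or ignore individual keyword arguments independently. This is not possible with jax.vmap, whose in_axes applies only to positional arguments (keyword arguments are always mapped over their leading axis):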

import jax.numpy as jnp
from jaxmore import vmap


def h(x, *, a, b):
    return x * a + b


# 'a' is mapped along axis 0, 'b' is broadcast (not mapped)
vh = vmap(h, in_kw={"a": 0, "b": None})
print(vh(jnp.ones(3), a=jnp.array([1.0, 2.0, 3.0]), b=10.0))
# Array([11., 12., 13.], dtype=float32)
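For comparison, getting the same per-argument behavior from plain jax.vmap means routing the keyword arguments through a positional wrapper so that in_axes can address them. A minimal core-JAX sketch (the lambda wrapper is just for illustration):

```python
import jax
import jax.numpy as jnp


def h(x, *, a, b):
    return x * a + b


# in_axes only describes positional arguments, so expose `a` and `b`
# positionally: map x and a along axis 0, broadcast b.
vh_ref = jax.vmap(lambda x, a, b: h(x, a=a, b=b), in_axes=(0, 0, None))
out = vh_ref(jnp.ones(3), jnp.array([1.0, 2.0, 3.0]), 10.0)
print(out)
```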

Broadcast a kwarg while mapping positional args:

import jax.numpy as jnp
from jaxmore import vmap


def f(x, *, scale):
    return x * scale


vf = vmap(f, in_kw=True, default_kw_axis=None)
print(vf(jnp.arange(3.0), scale=2.0))  # Array([0., 2., 4.], dtype=float32)
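This is where a broadcast kwarg axis matters: with plain jax.vmap the same call fails, because jax.vmap tries to map every keyword argument over axis 0 and a scalar has no axis 0. A small core-JAX check:

```python
import jax
import jax.numpy as jnp


def f(x, *, scale):
    return x * scale


# jax.vmap attempts to map the scalar kwarg `scale` along axis 0,
# which requires rank >= 1, so the call raises a ValueError.
try:
    jax.vmap(f)(jnp.arange(3.0), scale=2.0)
    scalar_kwarg_ok = True
except ValueError as err:
    scalar_kwarg_ok = False
    print("jax.vmap rejected the scalar kwarg:", type(err).__name__)
```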

Optional JIT — JIT-compile the static-folded function before vmapping:

import jax.numpy as jnp
from jaxmore import vmap


def mul(factor, x, *, offset):
    return factor * x + offset


vmul = vmap(mul, static_args=(3.0,), static_kw={"offset": 1.0}, jit=True)
print(vmul(jnp.arange(4.0)))  # Array([ 1.,  4.,  7., 10.], dtype=float32)
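The transformation order described above (fold the statics, JIT the per-example function, then vectorize) can be written out with core JAX as follows. This assumes jit=True corresponds roughly to this composition; the exact internals of jaxmore.vmap are not shown here:

```python
import functools

import jax
import jax.numpy as jnp


def mul(factor, x, *, offset):
    return factor * x + offset


# 1) fold constants, 2) compile the single-example function, 3) vectorize it.
folded = jax.jit(functools.partial(mul, 3.0, offset=1.0))
vmul_ref = jax.vmap(folded)
out = vmul_ref(jnp.arange(4.0))
print(out)
```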

bounded_while_loop

Simple loop over a scalar:

import jax.numpy as jnp
from jaxmore import bounded_while_loop


def cond_fn(x):
    return x < 5


def body_fn(x):
    return x + 1


result = bounded_while_loop(cond_fn, body_fn, jnp.asarray(0), max_steps=10)
print(result)  # Array(5, dtype=int32)
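The motivation for a bounded loop is reverse-mode differentiation. jax.lax.while_loop evaluates the same kind of loop forward without trouble, but JAX rejects jax.grad through it. A core-JAX illustration using a float variant of the loop above:

```python
import jax
from jax import lax


def cond_fn(x):
    return x < 5.0


def body_fn(x):
    return x + 1.0


def run(x0):
    return lax.while_loop(cond_fn, body_fn, x0)


print(run(0.0))  # forward evaluation works

# Reverse-mode differentiation through lax.while_loop is not supported.
try:
    jax.grad(run)(0.0)
    reverse_mode_ok = True
except ValueError as err:
    reverse_mode_ok = False
    print("reverse-mode differentiation failed:", type(err).__name__)
```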

PyTree carry (tuple):

import jax.numpy as jnp
from jaxmore import bounded_while_loop


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # (Array(3, dtype=int32), Array(8, dtype=int32))
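Since bounded_while_loop is described as being implemented via lax.scan, here is one common way to build such a loop. bounded_while_loop_sketch is a hypothetical helper written for illustration, not jaxmore's source. It always runs max_steps scan iterations and freezes the carry (leaf by leaf, so PyTree carries work) once cond_fn becomes False, which keeps the loop differentiable in reverse mode:

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop_sketch(cond_fn, body_fn, init_val, max_steps):
    """Hypothetical scan-based bounded while loop (illustration only)."""

    def step(carry, _):
        keep_going = cond_fn(carry)
        # body_fn runs on every step; its result is discarded once
        # the condition is False, so the carry stays frozen.
        new_carry = body_fn(carry)
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def cond_fn(state):
    x, _ = state
    return x < 3


def body_fn(state):
    x, y = state
    return x + 1, y * 2


result = bounded_while_loop_sketch(
    cond_fn, body_fn, (jnp.asarray(0), jnp.asarray(1)), max_steps=5
)
print(result)  # x == 3, y == 8


def final_y(y0):
    _, y = bounded_while_loop_sketch(
        cond_fn, body_fn, (jnp.asarray(0), y0), max_steps=5
    )
    return y


# y doubles three times before the loop stops, so d(final_y)/d(y0) == 8.
grad_y = jax.grad(final_y)(1.0)
print(grad_y)
```

The trade-off of this design is that body_fn is evaluated on all max_steps iterations even after the condition turns False, so max_steps should be a reasonably tight upper bound.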