Skip to content

Conversation

@mohsinm-dev
Copy link
Contributor

Resolves issue #5053 by automatically detecting and unbinding Module methods in JAX transforms.

  • Adds helper functions to detect bound methods and convert them to unbound callables
  • Updates all NNX transforms to handle bound methods seamlessly
  • Includes comprehensive test coverage for the new functionality
  • Maintains full backwards compatibility

- Add _resolve_bound_callable helper to detect and unbind Module methods
- Update all transforms (remat, jit, grad, value_and_grad, vmap, etc.) to handle bound methods
- Add comprehensive test suite for bound method functionality
- Resolves issue google#5053 TraceContextError with Module methods
@samanklesaria
Copy link
Collaborator

samanklesaria commented Oct 29, 2025

@mohsinm-dev I discussed the issue of nnx transforms on bound methods with the flax maintainers. They think there's just too many subtle bugs that might arise if we try to support all the different nnx transformations. All the different flavors of positional arguments (in_specs, static_argnums, in_axes, nondiff_argnums, etc) would need to be modified, resulting in a pretty substantial chunk of code. Instead, the consensus was that applying nnx transforms to bound methods should result in an error (similar to what you currently have for nnx.scan) encouraging users to use decorators or pass self arguments explicitly (as in nnx.transform(Model.method)(model, x) rather than nnx.transform(model.method)(x)). Apologies for not figuring this out before you put in all the work for the PR! Would you be willing to modify this PR to produce the relevant errors instead?

@mohsinm-dev
Copy link
Contributor Author

mohsinm-dev commented Oct 29, 2025

@samanklesaria Thanks for checking with the team that makes sense. I’ll update the PR to stop supporting bound-method callables across nnx transforms and instead raise a clear, consistent error (like nnx.scan already does) whenever a bound Module method is passed to nnx.grad/value_and_grad, nnx.remat, nnx.jit, nnx.vmap, nnx.pmap, nnx.checkify, and nnx.eval_shape. The error will explain the two supported patterns with examples: (1) use the decorator form on the method (@nnx.remat, @nnx.jit, etc.), which naturally operates on the unbound function, or (2) pass the unbound method
explicitly and supply the instance as the first argument (e.g., nnx.grad(Model.block)(model, x) rather than nnx.grad(model.block)(x)).

No worries about the work pivot, exploring the argnum complexity was valuable for understanding why bound method support is problematic

@samanklesaria samanklesaria requested a review from vfdev-5 October 30, 2025 20:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants