-
Notifications
You must be signed in to change notification settings - Fork 755
Fix bound method auto-unbinding for NNX transforms #5055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix bound method auto-unbinding for NNX transforms #5055
Conversation
- Add _resolve_bound_callable helper to detect and unbind Module methods - Update all transforms (remat, jit, grad, value_and_grad, vmap, etc.) to handle bound methods - Add comprehensive test suite for bound method functionality - Resolves issue google#5053 TraceContextError with Module methods
|
@mohsinm-dev I discussed the issue of nnx transforms on bound methods with the flax maintainers. They think there's just too many subtle bugs that might arise if we try to support all the different nnx transformations. All the different flavors of positional arguments (in_specs, static_argnums, in_axes, nondiff_argnums, etc) would need to be modified, resulting in a pretty substantial chunk of code. Instead, the consensus was that applying nnx transforms to bound methods should result in an error (similar to what you currently have for |
|
@samanklesaria Thanks for checking with the team that makes sense. I’ll update the PR to stop supporting bound-method callables across nnx transforms and instead raise a clear, consistent error (like nnx.scan already does) whenever a bound Module method is passed to nnx.grad/value_and_grad, nnx.remat, nnx.jit, nnx.vmap, nnx.pmap, nnx.checkify, and nnx.eval_shape. The error will explain the two supported patterns with examples: (1) use the decorator form on the method (@nnx.remat, @nnx.jit, etc.), which naturally operates on the unbound function, or (2) pass the unbound method No worries about the work pivot, exploring the argnum complexity was valuable for understanding why bound method support is problematic |
Resolves issue #5053 by automatically detecting and unbinding Module methods in JAX transforms.