Skip to content

Commit 12f1c15

Browse files
committed
Adding set_mode_info
1 parent 74985b2 commit 12f1c15

File tree

3 files changed

+174
-0
lines changed

3 files changed

+174
-0
lines changed

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .module import M as M
4949
from .module import Module as Module
5050
from .module import set_mode as set_mode
51+
from .module import set_mode_info as set_mode_info
5152
from .module import train_mode as train_mode
5253
from .module import eval_mode as eval_mode
5354
from .module import iter_children as iter_children, iter_modules as iter_modules

flax/nnx/module.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import inspect
1718
import typing as tp
1819

1920
import jax
@@ -559,6 +560,115 @@ def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
559560
)
560561

561562

563+
def _parse_docstring_args(doc_str: str) -> dict[str, str]:
564+
"""Parses parameters from `Args:` section of a function docstring.
565+
Assumes Google style docstrings. Returns a dictionary with
566+
keys representing argument names and values representing descriptions.
567+
Each description has lines starting with 4 spaces.
568+
"""
569+
lines = doc_str.split("\n")
570+
571+
# Get lines with the parameter names
572+
inds = [i for i, l in enumerate(lines) if l.startswith(" ") and not l.startswith(" ")]
573+
inds.append(len(lines))
574+
out = dict()
575+
576+
# Parse each argument
577+
for i in range(len(inds)-1):
578+
start, end = inds[i], inds[i+1]
579+
580+
# Process first line for the description
581+
first_colon = lines[start].find(":")
582+
name = lines[start][:first_colon].strip()
583+
desc = [" "*4 + lines[start][first_colon+1:].strip()]
584+
585+
# Append remaining description lines
586+
for j in range(start+1, end):
587+
desc.append(lines[j])
588+
589+
out[name] = "\n".join(desc)
590+
return out
591+
592+
593+
594+
def set_mode_info(node: Module, /, *, only: filterlib.Filter = ...) -> str:
595+
"""Provides information about the ``set_mode`` arguments for a module and all
596+
submodules. If no docstring is provided for a module's `set_mode`, this function
597+
puts the `set_mode` signature below the function.
598+
599+
Example::
600+
>>> from flax import nnx
601+
...
602+
>>> class CustomModel(nnx.Module):
603+
... def __init__(self, *, rngs):
604+
... self.mha = nnx.MultiHeadAttention(4, 8, 32, rngs=rngs)
605+
... self.drop = nnx.Dropout(0.5, rngs=rngs)
606+
... self.bn = nnx.BatchNorm(32, rngs=rngs)
607+
...
608+
>>> model = CustomModel(rngs=nnx.Rngs(0))
609+
>>> print(nnx.set_mode_info(model))
610+
BatchNorm:
611+
use_running_average: bool | None = None
612+
if True, the stored batch statistics will be
613+
used instead of computing the batch statistics on the input.
614+
Dropout:
615+
deterministic: bool | None = None
616+
if True, disables dropout masking.
617+
MultiHeadAttention:
618+
deterministic: bool | None = None
619+
if True, the module is set to deterministic mode.
620+
decode: bool | None = None
621+
if True, the module is set to decode mode.
622+
batch_size: int | Shape | None = None
623+
the batch size to use for the cache.
624+
max_length: int | None = None
625+
the max length to use for the cache.
626+
627+
Args:
628+
node: the object to display ``set_mode`` information for.
629+
only: Filters to select the Modules to display information for.
630+
"""
631+
predicate = filterlib.to_predicate(only)
632+
classes: set[Module] = set()
633+
634+
def _set_mode_info_fn(path, node):
635+
if hasattr(node, 'set_mode') and predicate(path, node):
636+
classes.add(node.__class__)
637+
return node
638+
639+
graph.recursive_map(_set_mode_info_fn, node)
640+
641+
class_list = sorted(list(classes), key=lambda x: x.__qualname__)
642+
out_str = []
643+
for c in class_list:
644+
out_str.append(f"{c.__qualname__}:")
645+
sig = inspect.signature(c.set_mode)
646+
doc = inspect.getdoc(c.set_mode)
647+
648+
# Parse docstring
649+
if isinstance(doc, str):
650+
start, end = doc.find("Args:\n"), doc.find("Returns:\n")
651+
if end == -1:
652+
end = len(doc)
653+
doc = doc[start+6:end]
654+
parsed_docstring = _parse_docstring_args(doc)
655+
656+
# Generate output from signature and docstring
657+
skip_names = {"self", "args", "kwargs"}
658+
for name, param in sig.parameters.items():
659+
if name in skip_names:
660+
continue
661+
662+
if param.default is inspect.Parameter.empty:
663+
out_str.append(f" {name}: {param.annotation}")
664+
else:
665+
out_str.append(f" {name}: {param.annotation} = {param.default}")
666+
out_str.append(parsed_docstring[name])
667+
else:
668+
out_str.append(f" set_mode{sig}")
669+
670+
671+
return "\n".join(out_str)
562672

563673
def first_from(*args: tp.Optional[A], error_msg: str) -> A:
564674
"""Return the first non-None argument.

tests/nnx/module_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,69 @@ def __call__(self, x):
759759
self.assertIn(str(expected_total_batch_stats), foo_repr[0])
760760
self.assertIn(str(expected_total_rng_states), foo_repr[0])
761761

762+
def test_set_mode_info(self):
763+
class Block(nnx.Module):
764+
def __init__(self, din, dout, *, rngs: nnx.Rngs):
765+
self.linear = nnx.Linear(din, dout, rngs=rngs)
766+
self.bn = nnx.BatchNorm(dout, rngs=rngs)
767+
self.dropout = nnx.Dropout(0.2, rngs=rngs)
768+
769+
def __call__(self, x):
770+
return nnx.relu(self.dropout(self.bn(self.linear(x))))
771+
772+
class Foo(nnx.Module):
773+
def __init__(self, rngs: nnx.Rngs):
774+
self.block1 = Block(32, 128, rngs=rngs)
775+
self.block2 = Block(128, 10, rngs=rngs)
776+
777+
def __call__(self, x):
778+
return self.block2(self.block1(x))
779+
780+
obj = Foo(rngs=nnx.Rngs(0))
781+
info_str = nnx.set_mode_info(obj)
782+
self.assertEqual(info_str.count("BatchNorm:"), 1)
783+
self.assertEqual(info_str.count("Dropout:"), 1)
784+
785+
def test_set_mode_info_with_filter(self):
786+
class Block(nnx.Module):
787+
def __init__(self, din, dout, *, rngs: nnx.Rngs):
788+
self.linear = nnx.Linear(din, dout, rngs=rngs)
789+
self.bn = nnx.BatchNorm(dout, rngs=rngs)
790+
self.dropout = nnx.Dropout(0.2, rngs=rngs)
791+
792+
def __call__(self, x):
793+
return nnx.relu(self.dropout(self.bn(self.linear(x))))
794+
795+
obj = Block(4, 8, rngs=nnx.Rngs(0))
796+
info_str = nnx.set_mode_info(obj, only=nnx.Dropout)
797+
self.assertIn("Dropout:", info_str)
798+
self.assertNotIn("BatchNorm:", info_str)
799+
800+
info_str = nnx.set_mode_info(obj, only=nnx.MultiHeadAttention)
801+
self.assertEmpty(info_str)
802+
803+
def test_set_mode_info_with_custom_set_mode(self):
804+
class Block(nnx.Module):
805+
def __init__(self, *, rngs: nnx.Rngs):
806+
pass
807+
808+
def __call__(self, x):
809+
return x
810+
811+
def set_mode(self, arg1: bool | None = None, arg2: int | None = None, **kwargs) -> dict:
812+
"""Example set mode test. This follows Google style docstrings.
813+
814+
Args:
815+
arg1: The first argument.
816+
arg2: The second argument.
817+
This has two lines.
818+
"""
819+
return kwargs
820+
821+
obj = Block(rngs=nnx.Rngs(0))
822+
info_str = nnx.set_mode_info(obj)
823+
self.assertEqual(f"{obj.__class__.__qualname__}:\n arg1: bool | None = None\n The first argument.\n arg2: int | None = None\n The second argument.\n This has two lines.", info_str)
824+
762825

763826
class TestModuleDataclass:
764827
def test_basic(self):

0 commit comments

Comments
 (0)