|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import inspect |
17 | 18 | import typing as tp |
18 | 19 |
|
19 | 20 | import jax |
@@ -559,6 +560,115 @@ def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A: |
559 | 560 | ) |
560 | 561 |
|
561 | 562 |
|
| 563 | +def _parse_docstring_args(doc_str: str) -> dict[str, str]: |
| 564 | + """Parses parameters from `Args:` section of a function docstring. |
| 565 | + Assumes Google style docstrings. Returns a dictionary with |
| 566 | + keys representing argument names and values representing descriptions. |
| 567 | + Each description has lines starting with 4 spaces. |
| 568 | + """ |
| 569 | + lines = doc_str.split("\n") |
| 570 | + |
| 571 | + # Get lines with the parameter names |
| 572 | + inds = [i for i, l in enumerate(lines) if l.startswith(" ") and not l.startswith(" ")] |
| 573 | + inds.append(len(lines)) |
| 574 | + out = dict() |
| 575 | + |
| 576 | + # Parse each argument |
| 577 | + for i in range(len(inds)-1): |
| 578 | + start, end = inds[i], inds[i+1] |
| 579 | + |
| 580 | + # Process first line for the description |
| 581 | + first_colon = lines[start].find(":") |
| 582 | + name = lines[start][:first_colon].strip() |
| 583 | + desc = [" "*4 + lines[start][first_colon+1:].strip()] |
| 584 | + |
| 585 | + # Append remaining description lines |
| 586 | + for j in range(start+1, end): |
| 587 | + desc.append(lines[j]) |
| 588 | + |
| 589 | + out[name] = "\n".join(desc) |
| 590 | + return out |
| 591 | + |
| 592 | + |
| 593 | + |
| 594 | +def set_mode_info(node: Module, /, *, only: filterlib.Filter = ...) -> str: |
| 595 | + """Provides information about the ``set_mode`` arguments for a module and all |
| 596 | + submodules. If no docstring is provided for a module's `set_mode`, this function |
| 597 | + puts the `set_mode` signature below the function. |
| 598 | +
|
| 599 | + Example:: |
| 600 | + >>> from flax import nnx |
| 601 | + ... |
| 602 | + >>> class CustomModel(nnx.Module): |
| 603 | + ... def __init__(self, *, rngs): |
| 604 | + ... self.mha = nnx.MultiHeadAttention(4, 8, 32, rngs=rngs) |
| 605 | + ... self.drop = nnx.Dropout(0.5, rngs=rngs) |
| 606 | + ... self.bn = nnx.BatchNorm(32, rngs=rngs) |
| 607 | + ... |
| 608 | + >>> model = CustomModel(rngs=nnx.Rngs(0)) |
| 609 | + >>> print(nnx.set_mode_info(model)) |
| 610 | + BatchNorm: |
| 611 | + use_running_average: bool | None = None |
| 612 | + if True, the stored batch statistics will be |
| 613 | + used instead of computing the batch statistics on the input. |
| 614 | + Dropout: |
| 615 | + deterministic: bool | None = None |
| 616 | + if True, disables dropout masking. |
| 617 | + MultiHeadAttention: |
| 618 | + deterministic: bool | None = None |
| 619 | + if True, the module is set to deterministic mode. |
| 620 | + decode: bool | None = None |
| 621 | + if True, the module is set to decode mode. |
| 622 | + batch_size: int | Shape | None = None |
| 623 | + the batch size to use for the cache. |
| 624 | + max_length: int | None = None |
| 625 | + the max length to use for the cache. |
| 626 | +
|
| 627 | + Args: |
| 628 | + node: the object to display ``set_mode`` information for. |
| 629 | + only: Filters to select the Modules to display information for. |
| 630 | + """ |
| 631 | + predicate = filterlib.to_predicate(only) |
| 632 | + classes: set[Module] = set() |
| 633 | + |
| 634 | + def _set_mode_info_fn(path, node): |
| 635 | + if hasattr(node, 'set_mode') and predicate(path, node): |
| 636 | + classes.add(node.__class__) |
| 637 | + return node |
| 638 | + |
| 639 | + graph.recursive_map(_set_mode_info_fn, node) |
| 640 | + |
| 641 | + class_list = sorted(list(classes), key=lambda x: x.__qualname__) |
| 642 | + out_str = [] |
| 643 | + for c in class_list: |
| 644 | + out_str.append(f"{c.__qualname__}:") |
| 645 | + sig = inspect.signature(c.set_mode) |
| 646 | + doc = inspect.getdoc(c.set_mode) |
| 647 | + |
| 648 | + # Parse docstring |
| 649 | + if isinstance(doc, str): |
| 650 | + start, end = doc.find("Args:\n"), doc.find("Returns:\n") |
| 651 | + if end == -1: |
| 652 | + end = len(doc) |
| 653 | + doc = doc[start+6:end] |
| 654 | + parsed_docstring = _parse_docstring_args(doc) |
| 655 | + |
| 656 | + # Generate output from signature and docstring |
| 657 | + skip_names = {"self", "args", "kwargs"} |
| 658 | + for name, param in sig.parameters.items(): |
| 659 | + if name in skip_names: |
| 660 | + continue |
| 661 | + |
| 662 | + if param.default is inspect.Parameter.empty: |
| 663 | + out_str.append(f" {name}: {param.annotation}") |
| 664 | + else: |
| 665 | + out_str.append(f" {name}: {param.annotation} = {param.default}") |
| 666 | + out_str.append(parsed_docstring[name]) |
| 667 | + else: |
| 668 | + out_str.append(f" set_mode{sig}") |
| 669 | + |
| 670 | + |
| 671 | + return "\n".join(out_str) |
562 | 672 |
|
563 | 673 | def first_from(*args: tp.Optional[A], error_msg: str) -> A: |
564 | 674 | """Return the first non-None argument. |
|
0 commit comments