
Commit d6f189b

docs: Add documentation on how complex numerics works
1 parent bb14415 commit d6f189b

8 files changed

Lines changed: 336 additions & 3 deletions


docsrc/contributors/complex_number_support.rst

Lines changed: 210 additions & 0 deletions
@@ -146,8 +146,218 @@ Key Implementation Invariants

Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.``
replaced with ``__`` before registration.

The Decomposition System — How It Is Built
-------------------------------------------

The rewriter is split across two classes and wired together by a lightweight
dispatch mechanism. This section walks through each piece and explains the
design decisions.

ComplexOpDetector — Subgraph Discovery
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexOpDetector`` walks the graph to find the set of nodes that participate
in complex arithmetic.

``node_include_in_subgraph``
""""""""""""""""""""""""""""

A node is included in a complex subgraph if:

1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or**
2. Any of its inputs are complex (``has_complex_input``).

The second condition is necessary to catch real-output ops — ``abs``, ``angle``,
``real``, ``imag`` — whose inputs are complex. These must be rewritten alongside
the rest of the subgraph even though their outputs are real.
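The two inclusion conditions can be sketched as a standalone predicate. ``SimpleNode`` and the helper bodies below are illustrative stand-ins for the real ``torch.fx.Node`` and the detector's helpers, not the actual implementation:

```python
from dataclasses import dataclass, field

COMPLEX_DTYPES = {"complex64", "complex128"}


@dataclass(frozen=True, eq=False)
class SimpleNode:
    # Illustrative stand-in for torch.fx.Node: just an output dtype
    # and the list of producer nodes feeding this node.
    name: str
    dtype: str
    inputs: tuple = field(default_factory=tuple)


def is_complex_dtype(node: SimpleNode) -> bool:
    return node.dtype in COMPLEX_DTYPES


def has_complex_input(node: SimpleNode) -> bool:
    return any(is_complex_dtype(inp) for inp in node.inputs)


def node_include_in_subgraph(node: SimpleNode) -> bool:
    # Condition 1: complex output dtype; condition 2: any complex input.
    return is_complex_dtype(node) or has_complex_input(node)


x = SimpleNode("x", "complex64")
absx = SimpleNode("abs", "float32", inputs=(x,))     # real output, complex input
scale = SimpleNode("scale", "float32", inputs=(absx,))

assert node_include_in_subgraph(x)       # included via condition 1
assert node_include_in_subgraph(absx)    # included only via condition 2
assert not node_include_in_subgraph(scale)  # purely real: excluded
```

Note that ``abs`` is only caught because of the second condition, which is exactly the real-output case described above.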

``subgraph_from_anchor``
""""""""""""""""""""""""

For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real``
*anchor* node and performs a backward BFS:

.. code-block:: text

   view_as_real ← mul (complex) ← reshape ← placeholder (complex)
   ↑ anchor       ↑ subgraph      ↑ subgraph ↑ input

At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is
added to the subgraph; otherwise it becomes an *input node* (the boundary). The
result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and
input nodes.

After collection the subgraph is **sorted in topological order** (by position in
the graph's node list). This is critical: without it a ``mul`` node could be
processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the
original complex node instead of the already-rewritten real node.
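Because the FX node list is itself topologically ordered, the sort reduces to a position-keyed sort; a minimal sketch with string stand-ins for nodes:

```python
def sort_subgraph_topologically(subgraph_nodes, all_graph_nodes):
    # The graph's node list is already in topological order, so sorting
    # the discovered subgraph by each node's position in that list
    # restores a valid processing order for the rewriter.
    position = {node: i for i, node in enumerate(all_graph_nodes)}
    return sorted(subgraph_nodes, key=lambda n: position[n])


# Backward BFS from the anchor discovers consumers before producers,
# e.g. mul before its sin/cos operands:
graph_order = ["freqs", "sin", "cos", "mul", "view_as_real"]
discovered = ["mul", "cos", "sin"]  # backward-BFS discovery order

assert sort_subgraph_topologically(discovered, graph_order) == ["sin", "cos", "mul"]
```

With the sort applied, ``sin`` and ``cos`` are rewritten before the ``mul`` that consumes them, which is the invariant the paragraph above calls out.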

``find_complex_op_subgraphs`` and subgraph merging
"""""""""""""""""""""""""""""""""""""""""""""""""""

When a model has multiple ``view_as_real`` anchors that share upstream nodes
(e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same
``freqs_cis`` placeholder), their subgraphs would otherwise be detected
separately. ``find_complex_op_subgraphs`` merges overlapping subgraphs by
set intersection so each node is rewritten exactly once.
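The merge is a plain union of node sets that share members; a sketch with hypothetical node names:

```python
def merge_overlapping_subgraphs(subgraphs):
    """Union subgraph node-sets that share at least one node (illustrative)."""
    merged = []
    for sg in subgraphs:
        sg = set(sg)
        # Absorb every already-merged group that intersects the new one.
        overlapping = [g for g in merged if g & sg]
        for g in overlapping:
            merged.remove(g)
            sg |= g
        merged.append(sg)
    return merged


# Two RoPE outputs sharing the freqs_cis placeholder collapse into one
# subgraph, so freqs_cis is rewritten exactly once:
a = {"freqs_cis", "mul_q", "xq_out"}
b = {"freqs_cis", "mul_k", "xk_out"}
c = {"other"}
result = merge_overlapping_subgraphs([a, b, c])
assert len(result) == 2
assert {"freqs_cis", "mul_q", "xq_out", "mul_k", "xk_out"} in result
```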

``find_all_complex_subgraphs`` — unbounded complex ops
"""""""""""""""""""""""""""""""""""""""""""""""""""""""

Some models produce a complex tensor as a graph *output* without passing it
through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that
collects every ``call_function`` node with a complex output, regardless of
anchoring. The resulting subgraph is processed the same way as an
anchor-bounded one.

ComplexGraphRewriter — Dispatch-Based Rewriting
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at
class-definition time scans every method for the ``@_complex_unpacker(op, ...)``
decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite
methods.

.. code-block:: python

   @_complex_unpacker(torch.ops.aten.mul.Tensor)
   def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...):
       ...
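One plausible way to implement the decorator pair, sketched with string op keys in place of real aten overloads (the actual source may differ):

```python
def _complex_unpacker(*ops):
    # Method decorator: tag the function with the ops it handles.
    def wrap(fn):
        fn._handles_ops = ops
        return fn
    return wrap


def _register_unpackers(cls):
    # Class decorator: at class-definition time, scan the class body for
    # tagged methods and build the op -> rewrite-method dispatch table.
    cls._DISPATCH = {}
    for name, member in vars(cls).items():
        for op in getattr(member, "_handles_ops", ()):
            cls._DISPATCH[op] = member
    return cls


@_register_unpackers
class DemoRewriter:
    @_complex_unpacker("aten.mul.Tensor")
    def _rewrite_mul(self, node):
        return f"rewrote {node}"


r = DemoRewriter()
assert DemoRewriter._DISPATCH["aten.mul.Tensor"](r, "mul_3") == "rewrote mul_3"
```

Because the table is built when the class body is executed, simply defining a new tagged method is enough to register it.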

The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically
ordered) subgraph nodes and for each node:

1. Looks up ``node.target`` in ``_DISPATCH``.
2. If found, calls the corresponding rewrite method.
3. If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies
   independently to every scalar, so the ``(..., 2)`` real layout is already
   correct).
4. Otherwise logs a warning and leaves the node unchanged.
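The four dispatch steps can be sketched with dictionary stand-ins for FX nodes (all names here are illustrative):

```python
import logging

logger = logging.getLogger("complex_rewrite_demo")

_ELEMENTWISE_SAFE = {"aten.add.Tensor", "aten.neg"}  # illustrative subset
_DISPATCH = {"aten.mul.Tensor": lambda node: f"decomposed {node['name']}"}


def rewrite_subgraph_nodes(ordered_nodes):
    # Mirrors the four steps: dispatch handler, elementwise-safe skip,
    # or warn and leave unchanged.
    results = []
    for node in ordered_nodes:
        handler = _DISPATCH.get(node["target"])
        if handler is not None:
            results.append(handler(node))                 # step 2
        elif node["target"] in _ELEMENTWISE_SAFE:
            results.append(f"skipped {node['name']}")     # step 3
        else:
            logger.warning("no rewrite for %s", node["target"])  # step 4
            results.append(f"unchanged {node['name']}")
    return results


nodes = [
    {"name": "add_1", "target": "aten.add.Tensor"},
    {"name": "mul_2", "target": "aten.mul.Tensor"},
    {"name": "angle_3", "target": "aten.angle"},
]
assert rewrite_subgraph_nodes(nodes) == [
    "skipped add_1", "decomposed mul_2", "unchanged angle_3"
]
```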

``_ELEMENTWISE_SAFE``
"""""""""""""""""""""

The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the
tensor independently — ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``,
``clone``, ``where``, etc. On the ``(..., 2)`` real layout these are already
correct: adding two complex tensors element-wise is the same as adding their
real and imaginary parts independently.

Notably **excluded** from this set:

* ``permute.default`` — must append the trailing real/imag dim index.
* ``add.Scalar`` / ``sub.Scalar`` — a scalar added to a complex number only
  shifts the real part; on the ``(..., 2)`` layout both parts would be shifted.
* ``reshape`` / ``view`` — shape arguments need updating for the extra ``2`` dim.
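The contrast between the safe and excluded cases can be checked directly, using Python complex numbers as ground truth for the ``(..., 2)`` pair layout:

```python
def as_pairs(zs):
    # The (..., 2) real layout: each complex number becomes [real, imag].
    return [[z.real, z.imag] for z in zs]


zs = [1 + 2j, 3 - 1j]
s = 10.0

# add.Scalar: a real scalar shifts only the real part...
correct = as_pairs([z + s for z in zs])
# ...but a naive elementwise add on the pair layout shifts both parts.
naive = [[re + s, im + s] for re, im in as_pairs(zs)]

assert correct == [[11.0, 2.0], [13.0, -1.0]]
assert naive == [[11.0, 12.0], [13.0, 9.0]]
assert naive != correct  # why add.Scalar is excluded from _ELEMENTWISE_SAFE

# add.Tensor, by contrast, IS elementwise-safe on the layout:
ws = [0.5 + 0.5j, -1 + 0j]
pairwise = [[a + c, b + d] for (a, b), (c, d) in zip(as_pairs(zs), as_pairs(ws))]
assert pairwise == as_pairs([z + w for z, w in zip(zs, ws)])
```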

Complex Multiply Decomposition
"""""""""""""""""""""""""""""""

The most important rewrite is ``mul.Tensor`` between two complex operands.
The rewriter calls ``complex_mul_replacement``:

.. code-block:: python

   # inputs a, b have shape (..., 2) — last dim is [real, imag]
   re_a = select(a, -1, 0); im_a = select(a, -1, 1)
   re_b = select(b, -1, 0); im_b = select(b, -1, 1)
   real_out = re_a * re_b - im_a * im_b   # ac - bd
   imag_out = re_a * im_b + im_a * re_b   # ad + bc
   result = stack([real_out, imag_out], dim=-1)
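The decomposition can be cross-checked against Python's built-in complex arithmetic, using scalar ``[real, imag]`` pairs as stand-ins for the ``select``/``stack`` tensors:

```python
def complex_mul_pairs(a, b):
    # a, b: [real, imag] pairs — the decomposition's (..., 2) layout.
    re_a, im_a = a
    re_b, im_b = b
    real_out = re_a * re_b - im_a * im_b  # ac - bd
    imag_out = re_a * im_b + im_a * re_b  # ad + bc
    return [real_out, imag_out]


# Ground truth: Python's native complex multiply.
for a, b in [(1 + 2j, 3 - 1j), (0.5 - 0.5j, -2 + 4j)]:
    got = complex_mul_pairs([a.real, a.imag], [b.real, b.imag])
    assert got == [(a * b).real, (a * b).imag]
```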

Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node,
so all of the new nodes appear immediately after it in topological order. The
original ``mul`` node is then replaced and erased.

See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages
cursor-based insertion.

The ``originally_complex`` Invariant
"""""""""""""""""""""""""""""""""""""

Input replacement (Stage 2) converts complex ``placeholder`` nodes to
``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those
nodes even though they logically represent complex quantities.

To avoid missed rewrites, the rewriter records the set of nodes that were complex
*before any rewrites* in ``originally_complex``. The ``mul.Tensor`` dispatch
handler only triggers the full complex-multiply decomposition when the ``mul``
node appears in ``originally_complex``; real multiplies that happen to follow a
complex input (e.g. an ``abs`` followed by a real-valued scale) are left alone.

FakeTensorMode Reuse for Dynamic Shapes
"""""""""""""""""""""""""""""""""""""""""

When inserting a new ``placeholder`` for a complex input, the pass must populate
``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh
``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible
with the one that ``torch.export`` used to encode dynamic shape constraints
(SymInt ranges).

The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's
``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed
by appending a concrete ``2`` to the symbolic shape list:

.. code-block:: python

   orig_fake = input_node.meta["val"]
   sym_shape = list(orig_fake.shape) + [2]
   with orig_fake.fake_mode:
       fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device)

This preserves all SymInt identity across the graph and keeps
dynamic-shape exports working correctly.

Entry Point: ``complex_graph_detection``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The public entry point called by the lowering pipeline is
``complex_graph_detection(gm, settings)``. It:

1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``.
2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find
   bounded complex subgraphs.
3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that
   are not ``view_as_real``-bounded.
4. For each subgraph:

   a. Calls ``replace_input_node`` on every boundary input node (Stage 2).
   b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3).
   c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes.

5. Returns the modified ``GraphModule``.
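The per-subgraph loop (steps 4a–4c) can be sketched abstractly, with the stage functions passed in as plain callables rather than the real graph-mutating methods:

```python
def complex_graph_detection_sketch(subgraphs, replace_input_node,
                                   rewrite_subgraph_nodes, clean_up):
    # Illustrative orchestration of Stages 2-3 per subgraph; the real
    # entry point operates on an fx.GraphModule and the detector/rewriter
    # classes rather than dicts and lambdas.
    log = []
    for sg in subgraphs:
        for inp in sg["inputs"]:
            log.append(replace_input_node(inp))          # Stage 2
        log.append(rewrite_subgraph_nodes(sg["nodes"]))  # Stage 3
        log.append(clean_up())                           # dead-node cleanup
    return log


log = complex_graph_detection_sketch(
    [{"inputs": ["freqs_cis"], "nodes": ["mul", "view_as_real"]}],
    replace_input_node=lambda n: f"replace {n}",
    rewrite_subgraph_nodes=lambda ns: f"rewrite {len(ns)} nodes",
    clean_up=lambda: "cleaned",
)
assert log == ["replace freqs_cis", "rewrite 2 nodes", "cleaned"]
```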

Adding New Op Rewrites
^^^^^^^^^^^^^^^^^^^^^^^

To teach the rewriter about a new complex op, add a method to
``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``:

.. code-block:: python

   @_complex_unpacker(torch.ops.aten.my_new_op.default)
   def _rewrite_my_new_op(
       self,
       node: Node,
       originally_complex: set,
   ) -> None:
       inp = node.args[0]
       with SubgraphBuilder(self.gm.graph, node) as b:
           re = b(torch.ops.aten.select.int, inp, -1, 0)
           im = b(torch.ops.aten.select.int, inp, -1, 1)
           result = b(my_real_impl, re, im)
       node.replace_all_uses_with(result)
       self.gm.graph.erase_node(node)

``@_register_unpackers`` (applied to the class) picks up the new entry
automatically at import time — no other registration is required.

If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts
independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead.

Related
-------

* :ref:`lowering` — the complex rewrite is a lowering pass.
* :ref:`subgraph_builder` — the ``SubgraphBuilder`` helper used in every rewrite method.
* :ref:`lowering_passes_catalog` — pass ordering and management.

docsrc/tutorials/advanced_usage.rst

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,7 @@ Advanced Usage
 ==============

 Step-by-step tutorials covering engine caching, quantization, custom kernels,
-dynamic shapes, weight streaming, debugging, and more.
+dynamic shapes, weight streaming, debugging, complex numerics, and more.

 .. toctree::
    :maxdepth: 2
@@ -14,5 +14,6 @@
    weight_refit/index
    runtime_opt/index
    deployment/index
+   complex_numerics/index
    Example: Distributed Inference <_rendered_examples/distributed_inference/index>
    ../indices/supported_ops

docsrc/tutorials/deployment/complex_tensors.rst renamed to docsrc/tutorials/complex_numerics/complex_tensors.rst

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,12 @@ compilation.

This page explains what the rewriter does, which patterns are supported, and what
limitations to be aware of when compiling models with complex inputs.

.. seealso::

   :doc:`../_rendered_examples/dynamo/torch_export_3d_rope` — a runnable
   end-to-end example compiling a video-transformer 3D RoPE attention block
   (CogVideoX / Wan / HunyuanVideo style) with dynamic T×H×W shapes.

----

How the Rewriter Works
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
Complex Numerics
===================

Compatibility support for numerical datatypes, such as complex numbers, that are not natively supported by TensorRT.

.. toctree::
   :maxdepth: 1

   complex_tensors
   Example: 3D RoPE with Complex Numerics <../_rendered_examples/dynamo/torch_export_3d_rope>

docsrc/tutorials/deployment/index.rst

Lines changed: 0 additions & 1 deletion
@@ -12,4 +12,3 @@ complex-valued model support.
    cross_compile_windows
    Example: Cross-runtime Compilation for Windows <../_rendered_examples/dynamo/cross_runtime_compilation_for_windows>
    distributed_inference
-   complex_tensors

docsrc/tutorials/extensibility/lowering/index.rst

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ rewrite ATen ops before TensorRT compilation.
    :maxdepth: 1

    writing_dynamo_aten_lowering_passes
+   subgraph_builder
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
.. _subgraph_builder:

SubgraphBuilder — Cursor-Based FX Node Insertion
=================================================

Writing lowering passes that replace one node with several new nodes requires
careful management of insertion order: each new node must be inserted
*after the previous one* so that the topological ordering of the graph is
preserved. Doing this by hand with repeated ``graph.inserting_after(cursor)``
context managers is verbose and error-prone.

``SubgraphBuilder`` is a lightweight context-manager helper in
``torch_tensorrt.dynamo.lowering._SubgraphBuilder`` that automates this
cursor-tracking pattern.

Basic Usage
-----------

Construct a ``SubgraphBuilder`` with the target graph and the *anchor* node —
the node immediately before where you want to start inserting. Then use it
as a callable inside a ``with`` block to add nodes one at a time:

.. code-block:: python

   import torch
   from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder

   aten = torch.ops.aten  # op namespaces are attributes, not importable modules

   # Inside a lowering pass, given a node `mul_node` to replace:
   a, b_node = mul_node.args  # the two complex (..., 2) operands
   with SubgraphBuilder(gm.graph, mul_node) as b:
       # Each call inserts a node after the current cursor and advances it.
       re_a = b(aten.select.int, a, -1, 0)       # a[..., 0] (real part of a)
       im_a = b(aten.select.int, a, -1, 1)       # a[..., 1] (imag part of a)
       re_b = b(aten.select.int, b_node, -1, 0)
       im_b = b(aten.select.int, b_node, -1, 1)
       real = b(aten.sub.Tensor, b(aten.mul.Tensor, re_a, re_b),
                b(aten.mul.Tensor, im_a, im_b))  # ac - bd
       imag = b(aten.add.Tensor, b(aten.mul.Tensor, re_a, im_b),
                b(aten.mul.Tensor, im_a, re_b))  # ad + bc
       result = b(aten.stack, [real, imag], -1)

   mul_node.replace_all_uses_with(result)
   gm.graph.erase_node(mul_node)

On ``__exit__``, the builder automatically calls ``graph.lint()`` to validate
the modified graph. If your code raises an exception inside the block, the
lint is skipped so you see the original error rather than a secondary graph
validation failure.

How It Works
------------

The builder maintains a *cursor* — initially the anchor node passed to
``__init__``. Every time you call it:

1. A new ``call_function`` node is inserted via ``graph.inserting_after(cursor)``.
2. The cursor advances to the newly inserted node.
3. The new node is appended to an internal ``_inserted`` list for debug logging.

This ensures that successive calls produce a correctly ordered chain:

.. code-block:: text

   anchor → node_0 → node_1 → node_2 → ...

without any manual bookkeeping.
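The cursor pattern itself can be shown independently of FX with a toy graph (an ordered list of node names); the real builder wraps ``graph.inserting_after`` and ``graph.call_function`` instead:

```python
class ToyGraph:
    # Minimal stand-in for an FX graph: an ordered list of node names.
    def __init__(self, nodes):
        self.nodes = list(nodes)

    def insert_after(self, cursor, name):
        self.nodes.insert(self.nodes.index(cursor) + 1, name)


class ToyBuilder:
    """Cursor-based insertion: each call lands after the previous node."""

    def __init__(self, graph, anchor):
        self.graph = graph
        self.cursor = anchor
        self._inserted = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # The real builder lints the graph here on success (and skips
        # the lint when an exception is propagating).
        return False

    def __call__(self, name):
        self.graph.insert_after(self.cursor, name)
        self.cursor = name          # advance so the next call lands after this one
        self._inserted.append(name)
        return name


g = ToyGraph(["placeholder", "mul", "output"])
with ToyBuilder(g, "mul") as b:
    b("select_re")
    b("select_im")
    b("stack")

# New nodes form an ordered chain immediately after the anchor:
assert g.nodes == ["placeholder", "mul", "select_re", "select_im", "stack", "output"]
```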

Debug Logging
-------------

When the ``torch_tensorrt`` logger is set to ``DEBUG``, the builder emits a
compact summary of all inserted nodes after a successful block, for example::

   rewrite %mul_17[(4, 32, 2),torch.float32] ->
     %select_72[(4, 32),torch.float32] = select_int(%inp_0, -1, 0)
     %select_73[(4, 32),torch.float32] = select_int(%inp_0, -1, 1)
     %mul_18[(4, 32),torch.float32] = mul_Tensor(%select_72, %select_73)
     ...

This makes it easy to trace exactly which nodes were produced by a particular
rewrite rule.

API Reference
-------------

.. autoclass:: torch_tensorrt.dynamo.lowering._SubgraphBuilder.SubgraphBuilder
   :members:
   :undoc-members:

When to Use SubgraphBuilder
---------------------------

Use ``SubgraphBuilder`` whenever a lowering pass needs to **expand one node into
a sequence of several nodes** in a single linear chain. Typical use cases:

* Replacing a complex-arithmetic op with real-arithmetic equivalents
  (e.g. the ``complex_mul_replacement`` in :ref:`complex_number_support_design`).
* Decomposing a high-level op (e.g. ``layer_norm``) into its ATen primitives
  when a custom replacement strategy is needed beyond the standard decomposition
  table.
* Inserting diagnostic nodes (shape probes, debug prints) around a target op.

If you only need to insert a *single* node, a plain
``graph.inserting_after(node)`` is simpler. If you need to insert into multiple
disconnected locations in the same pass, create a separate ``SubgraphBuilder``
for each anchor.

examples/dynamo/README.rst

Lines changed: 2 additions & 1 deletion
@@ -25,4 +25,5 @@ Model Zoo
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
 * :ref:`debugger_example`: Debugging Torch-TensorRT Compilation
+* :ref:`torch_export_3d_rope`: Compiling a 3D RoPE video-transformer block with complex numerics support
