@@ -146,8 +146,218 @@ Key Implementation Invariants
146146 Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.``
147147 replaced with ``__`` before registration.
148148
149+ The Decomposition System — How It Is Built
150+ -------------------------------------------
151+
152+ The rewriter is split across two classes and wired together by a lightweight
153+ dispatch mechanism. This section walks through each piece and explains the
154+ design decisions.
155+
156+ ComplexOpDetector — Subgraph Discovery
157+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
158+
159+ ``ComplexOpDetector`` walks the graph to find the set of nodes that participate
160+ in complex arithmetic.
161+
162+ ``node_include_in_subgraph``
163+ """"""""""""""""""""""""""""
164+
165+ A node is included in a complex subgraph if:
166+
167+ 1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or**
168+ 2. Any of its inputs is complex (``has_complex_input``).
169+
170+ The second condition is necessary to catch real-output ops (``abs``, ``angle``,
171+ ``real``, ``imag``) whose inputs are complex. These must be rewritten alongside
172+ the rest of the subgraph even though their outputs are real.
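The inclusion rule above can be sketched with a toy node type. This is only an illustration: ``ToyNode`` and its string dtypes are hypothetical stand-ins for ``torch.fx.Node`` and its metadata, and the real helpers may inspect nodes differently.

```python
from dataclasses import dataclass, field
from typing import List

COMPLEX_DTYPES = {"complex64", "complex128"}


@dataclass
class ToyNode:
    """Hypothetical stand-in for an FX node: a name, an output dtype, and inputs."""
    name: str
    dtype: str
    inputs: List["ToyNode"] = field(default_factory=list)


def is_complex_dtype(node: ToyNode) -> bool:
    # Condition 1: the node's own output is complex.
    return node.dtype in COMPLEX_DTYPES


def has_complex_input(node: ToyNode) -> bool:
    # Condition 2: at least one direct input is complex.
    return any(is_complex_dtype(inp) for inp in node.inputs)


def node_include_in_subgraph(node: ToyNode) -> bool:
    return is_complex_dtype(node) or has_complex_input(node)


z = ToyNode("z", "complex64")
abs_z = ToyNode("abs", "float32", [z])        # real output, complex input
scale = ToyNode("scale", "float32", [abs_z])  # real output, real input
```

Here ``abs_z`` is included only because of the second condition, while ``scale`` is excluded because neither condition holds.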
173+
174+ ``subgraph_from_anchor``
175+ """"""""""""""""""""""""
176+
177+ For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real``
178+ *anchor* node and performs a backward BFS:
179+
180+ .. code-block:: text
181+
182+     view_as_real ← mul (complex) ← reshape ← placeholder (complex)
183+     ↑ anchor       ↑ subgraph      ↑ subgraph ↑ input
184+
185+ At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is
186+ added to the subgraph; otherwise it becomes an *input node* (the boundary). The
187+ result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and
188+ input nodes.
189+
190+ After collection the subgraph is **sorted in topological order** (by position in
191+ the graph's node list). This is critical: without it a ``mul`` node could be
192+ processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the
193+ original complex node instead of the already-rewritten real node.
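A minimal sketch of the backward BFS and the final ordering step, on a hypothetical toy graph stored as a dictionary. It assumes, following the diagram above, that placeholders always end up on the boundary; the names ``NODES``, ``ORDER`` and ``include`` are illustrative and not part of the actual pass.

```python
from collections import deque

# Hypothetical toy graph, listed in graph (topological) order.
# name -> (op kind, output is complex?, input names)
NODES = {
    "x":            ("placeholder",   True,  []),
    "scale":        ("placeholder",   False, []),
    "reshape":      ("call_function", True,  ["x"]),
    "mul":          ("call_function", True,  ["reshape", "scale"]),
    "view_as_real": ("call_function", False, ["mul"]),
}
ORDER = list(NODES)  # dicts preserve insertion order


def include(name):
    kind, is_complex, inputs = NODES[name]
    if kind == "placeholder":
        return False  # assumption: placeholders always form the boundary
    return is_complex or any(NODES[i][1] for i in inputs)


def subgraph_from_anchor(anchor):
    subgraph, input_nodes, seen = set(), set(), set()
    queue = deque(NODES[anchor][2])  # start from the anchor's inputs
    while queue:
        name = queue.popleft()
        if name in seen:
            continue
        seen.add(name)
        if include(name):
            subgraph.add(name)
            queue.extend(NODES[name][2])  # keep walking upstream
        else:
            input_nodes.add(name)  # boundary: stop here
    # Sort by position in the graph's node list, i.e. topological order.
    return sorted(subgraph, key=ORDER.index), sorted(input_nodes, key=ORDER.index)


ordered, boundary = subgraph_from_anchor("view_as_real")
```

The BFS visits ``mul`` before ``reshape``, so the final sort is what guarantees that ``reshape`` is rewritten first.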
194+
195+ ``find_complex_op_subgraphs`` and subgraph merging
196+ """""""""""""""""""""""""""""""""""""""""""""""""""
197+
198+ When a model has multiple ``view_as_real`` anchors that share upstream nodes
199+ (e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same
200+ ``freqs_cis`` placeholder), their subgraphs would otherwise be detected
201+ separately. ``find_complex_op_subgraphs`` merges any subgraphs whose node
202+ sets intersect, so each node is rewritten exactly once.
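One way to picture the merge step is as repeated set union over groups that overlap. The sketch below uses plain Python sets of hypothetical node names; the actual method operates on ``ComplexSubGraphInfo`` objects and may be structured differently.

```python
def merge_overlapping(subgraphs):
    """Merge node sets that share at least one node (illustrative only)."""
    merged = []
    for nodes in subgraphs:
        nodes = set(nodes)
        disjoint = []
        for group in merged:
            if group & nodes:      # overlap detected by intersection
                nodes |= group     # merged by union
            else:
                disjoint.append(group)
        merged = disjoint + [nodes]
    return merged


# Two RoPE-style branches sharing an upstream node, plus an unrelated subgraph.
xq = {"freqs_reshape", "xq_mul"}
xk = {"freqs_reshape", "xk_mul"}
other = {"fft_mul"}
result = merge_overlapping([xq, xk, other])
```

After merging, the shared ``freqs_reshape`` node belongs to a single group, so it is visited once by the rewriter.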
203+
204+ ``find_all_complex_subgraphs``: unbounded complex ops
205+ """""""""""""""""""""""""""""""""""""""""""""""""""""""
206+
207+ Some models produce a complex tensor as a graph *output* without passing it
208+ through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that
209+ collects every ``call_function `` node with a complex output, regardless of
210+ anchoring. The resulting subgraph is processed the same way as an
211+ anchor-bounded one.
212+
213+ ComplexGraphRewriter — Dispatch-Based Rewriting
214+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
215+
216+ ``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at
217+ class-definition time scans every method for the ``@_complex_unpacker(op, ...)``
218+ decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite
219+ methods.
220+
221+ .. code-block:: python
222+
223+     @_complex_unpacker(torch.ops.aten.mul.Tensor)
224+     def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...):
225+         ...
226+
227+ The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically
228+ ordered) subgraph nodes and for each node:
229+
230+ 1. Looks up ``node.target`` in ``_DISPATCH``.
231+ 2. If found, calls the corresponding rewrite method.
232+ 3. If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies
233+    independently to every scalar, so the ``(..., 2)`` real layout is already
234+    correct).
235+ 4. Otherwise logs a warning and leaves the node unchanged.
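The registration and dispatch pattern described above can be approximated in a few lines. This is a simplified sketch, not the actual implementation: ops are plain strings instead of aten overloads, ``ToyRewriter`` is hypothetical, and the real decorators may store their metadata differently.

```python
def _complex_unpacker(*ops):
    """Tag a method with the ops it rewrites (sketch)."""
    def decorator(fn):
        fn._complex_ops = ops
        return fn
    return decorator


def _register_unpackers(cls):
    """Build cls._DISPATCH from tagged methods at class-definition time (sketch)."""
    cls._DISPATCH = {}
    for attr in list(vars(cls).values()):
        for op in getattr(attr, "_complex_ops", ()):
            cls._DISPATCH[op] = attr
    return cls


@_register_unpackers
class ToyRewriter:
    _ELEMENTWISE_SAFE = {"add.Tensor", "neg"}

    @_complex_unpacker("mul.Tensor")
    def _rewrite_mul(self, target):
        return f"rewrote {target}"

    def rewrite_subgraph_nodes(self, targets):
        log = []
        for target in targets:
            handler = self._DISPATCH.get(target)
            if handler is not None:
                log.append(handler(self, target))      # steps 1 and 2
            elif target in self._ELEMENTWISE_SAFE:
                log.append(f"skipped {target}")        # step 3
            else:
                log.append(f"warning: no rewrite for {target}")  # step 4
        return log


log = ToyRewriter().rewrite_subgraph_nodes(["mul.Tensor", "add.Tensor", "fft.default"])
```

Because the table is built by the class decorator, adding a handler only requires tagging a new method; no central list needs editing.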
236+
237+ ``_ELEMENTWISE_SAFE``
238+ """""""""""""""""""""
239+
240+ The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the
241+ tensor independently, such as ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``,
242+ ``clone``, and ``where``. On the ``(..., 2)`` real layout these are already
243+ correct: adding two complex tensors element-wise is the same as adding their
244+ real and imaginary parts independently.
245+
246+ Notably **excluded** from this set:
247+
248+ * ``permute.default``: must append the trailing real/imag dim index.
249+ * ``add.Scalar`` / ``sub.Scalar``: a scalar added to a complex number only
250+   shifts the real part; on the ``(..., 2)`` layout both parts would be shifted.
251+ * ``reshape`` / ``view``: shape arguments need updating for the extra ``2`` dim.
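The difference between ``add.Tensor`` and ``add.Scalar`` on the real layout can be checked with plain Python complex numbers. The helpers ``to_pairs`` and ``from_pairs`` are hypothetical and simply mimic the ``(..., 2)`` layout with nested lists.

```python
def to_pairs(zs):
    """Complex values -> [[real, imag], ...], mimicking the (..., 2) layout."""
    return [[z.real, z.imag] for z in zs]


def from_pairs(pairs):
    return [complex(r, i) for r, i in pairs]


z = [1 + 2j, 3 - 1j]
w = [0.5 + 0.5j, -1 + 4j]
z_pairs, w_pairs = to_pairs(z), to_pairs(w)

# add.Tensor: adding the pairs element-wise matches complex addition.
tensor_sum = [[a + c, b + d] for (a, b), (c, d) in zip(z_pairs, w_pairs)]

# add.Scalar applied naively to the real layout shifts BOTH parts (wrong).
naive_scalar = [[r + 2.0, i + 2.0] for r, i in z_pairs]

# Correct handling shifts only the real part.
fixed_scalar = [[r + 2.0, i] for r, i in z_pairs]
```

Only ``fixed_scalar`` round-trips to ``z + 2``, which is why the scalar variants need a dedicated rewrite.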
252+
253+ Complex Multiply Decomposition
254+ """""""""""""""""""""""""""""""
255+
256+ The most important rewrite is ``mul.Tensor`` between two complex operands.
257+ The rewriter calls ``complex_mul_replacement``:
258+
259+ .. code-block:: python
260+
261+     # inputs a, b have shape (..., 2); last dim is [real, imag]
262+     re_a = select(a, -1, 0); im_a = select(a, -1, 1)
263+     re_b = select(b, -1, 0); im_b = select(b, -1, 1)
264+     real_out = re_a * re_b - im_a * im_b  # ac - bd
265+     imag_out = re_a * im_b + im_a * re_b  # ad + bc
266+     result = stack([real_out, imag_out], dim=-1)
267+
268+ Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node,
269+ so all of the new nodes appear immediately after it in topological order. The
270+ original ``mul`` node is then replaced and erased.
271+
272+ See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages
273+ cursor-based insertion.
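As a quick sanity check of the decomposition formula, the same arithmetic can be run on a single ``[real, imag]`` pair and compared with Python's built-in complex multiplication. ``complex_mul_pairs`` is a hypothetical scalar analogue, not the graph-level ``complex_mul_replacement``.

```python
def complex_mul_pairs(a, b):
    """Multiply two [real, imag] pairs using only real arithmetic."""
    re_a, im_a = a
    re_b, im_b = b
    real_out = re_a * re_b - im_a * im_b  # ac - bd
    imag_out = re_a * im_b + im_a * re_b  # ad + bc
    return [real_out, imag_out]


a = [1.0, 2.0]    # 1 + 2j
b = [3.0, -4.0]   # 3 - 4j
out = complex_mul_pairs(a, b)
expected = complex(*a) * complex(*b)
```

For these inputs the product is ``11 + 2j``, and ``out`` holds the same real and imaginary parts.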
274+
275+ The ``originally_complex`` Invariant
276+ """""""""""""""""""""""""""""""""""""
277+
278+ Input replacement (Stage 2) converts complex ``placeholder`` nodes to
279+ ``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those
280+ nodes even though they logically represent complex quantities.
281+
282+ To avoid missed rewrites, the rewriter records the set of nodes that were complex
283+ *before any rewrites* in ``originally_complex``. The ``mul.Tensor`` dispatch
284+ handler only triggers the full complex-multiply decomposition when the ``mul``
285+ node appears in ``originally_complex``; real multiplies that happen to follow a
286+ complex input (e.g. an ``abs`` followed by a real-valued scale) are left alone.
287+
288+ FakeTensorMode Reuse for Dynamic Shapes
289+ """""""""""""""""""""""""""""""""""""""""
290+
291+ When inserting a new ``placeholder`` for a complex input, the pass must populate
292+ ``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh
293+ ``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible
294+ with the one that ``torch.export`` used to encode dynamic shape constraints
295+ (SymInt ranges).
296+
297+ The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's
298+ ``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed
299+ by appending a concrete ``2`` to the symbolic shape list:
300+
301+ .. code-block:: python
302+
303+     orig_fake = input_node.meta["val"]
304+     sym_shape = list(orig_fake.shape) + [2]
305+     with orig_fake.fake_mode:
306+         fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device)
307+
308+ This preserves all SymInt identity across the graph and keeps
309+ dynamic-shape exports working correctly.
310+
311+ Entry Point: ``complex_graph_detection``
312+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
313+
314+ The public entry point called by the lowering pipeline is
315+ ``complex_graph_detection(gm, settings)``. It:
316+
317+ 1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``.
318+ 2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find
319+    bounded complex subgraphs.
320+ 3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that
321+    are not ``view_as_real``-bounded.
322+ 4. For each subgraph:
323+
324+    a. Calls ``replace_input_node`` on every boundary input node (Stage 2).
325+    b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3).
326+    c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes.
327+
328+ 5. Returns the modified ``GraphModule``.
329+
330+ Adding New Op Rewrites
331+ ^^^^^^^^^^^^^^^^^^^^^^^
332+
333+ To teach the rewriter about a new complex op, add a method to
334+ ``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``:
335+
336+ .. code-block:: python
337+
338+     @_complex_unpacker(torch.ops.aten.my_new_op.default)
339+     def _rewrite_my_new_op(
340+         self,
341+         node: Node,
342+         originally_complex: set,
343+     ) -> None:
344+         inp = node.args[0]
345+         with SubgraphBuilder(self.gm.graph, node) as b:
346+             re = b(torch.ops.aten.select.int, inp, -1, 0)
347+             im = b(torch.ops.aten.select.int, inp, -1, 1)
348+             result = b(my_real_impl, re, im)
349+         node.replace_all_uses_with(result)
350+         self.gm.graph.erase_node(node)
351+
352+ ``@_register_unpackers`` (applied to the class) picks up the new entry
353+ automatically at import time; no other registration is required.
354+
355+ If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts
356+ independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead.
357+
149358Related
150359-------
151360
152361* :ref:`lowering`: the complex rewrite is a lowering pass.
362+ * :ref:`subgraph_builder`: the ``SubgraphBuilder`` helper used in every rewrite method.
153363* :ref:`lowering_passes_catalog`: pass ordering and management.