Skip to content

Commit 523da9d

Browse files
committed
Reuse _is_following_cask_partition
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 34139ec commit 523da9d

File tree

1 file changed

+47
-48
lines changed

1 file changed

+47
-48
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,35 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[
324324
return tensor_consumer_map
325325

326326

327+
def _is_following_cask_partition(
    node: Node, cask_partition_nodes: set[str], max_depth: int = 10
) -> bool:
    """Check if a CASK fusible partition can be reached by traversing backward through copy ops.

    A node qualifies when its name is in ``cask_partition_nodes``, or when it is a
    copy op (per ``is_copy_op``) that has at least one parent and *every* parent
    returned by ``get_parent_nodes`` qualifies in turn.

    Args:
        node: The node to check.
        cask_partition_nodes: Set of node names belonging to CASK partitions.
        max_depth: Maximum recursion depth to guard against pathological graphs.
            Each backward hop through a copy op consumes one unit; once the
            budget reaches zero, only direct partition membership is accepted.

    Returns:
        True if the node belongs to a CASK partition, or is separated from one
        only by copy ops within ``max_depth`` hops; False otherwise.
    """
    # Membership is tested before the depth guard so that a partition node is
    # still recognised when the depth budget has just been exhausted.
    if node.name in cask_partition_nodes:
        return True

    # Only copy ops may be looked through; any other op ends the search.
    if max_depth <= 0 or not is_copy_op(node.op):
        return False

    parent_nodes = get_parent_nodes(node)
    # Explicit guard: all() over an empty iterable is True, which would wrongly
    # accept a copy op that has no producer nodes at all.
    if not parent_nodes:
        return False

    return all(
        _is_following_cask_partition(parent, cask_partition_nodes, max_depth - 1)
        for parent in parent_nodes
    )
354+
355+
327356
def find_conv_to_layernorm_nodes(
328357
graph: Graph,
329358
cask_fusible_partitions: list[list[Node]],
@@ -342,77 +371,44 @@ def find_conv_to_layernorm_nodes(
342371
Returns:
343372
List of LayerNormalization nodes that consume CASK partition outputs.
344373
"""
345-
# Collect the output tensor names from CASK partitions
346-
cask_output_tensor_names = set()
374+
cask_partition_nodes: set[str] = set()
347375
for partition in cask_fusible_partitions:
348-
last_node = partition[-1]
349-
for output_tensor in last_node.outputs:
350-
cask_output_tensor_names.add(output_tensor.name)
376+
cask_partition_nodes.update(node.name for node in partition)
351377

352378
conv_to_ln_nodes = []
353379
for node in graph.nodes:
354380
if node.op != "LayerNormalization":
355381
continue
356382

357-
# Check if the first input (activation) comes from a CASK partition output
383+
# Check if the first input (activation) comes from a CASK partition
358384
# possibly through copy ops (Reshape, Transpose, etc.)
359-
if _is_input_from_cask_partition(node.inputs[0], cask_output_tensor_names):
360-
conv_to_ln_nodes.append(node)
361-
logger.debug(
362-
f"Found Conv->LayerNorm pattern: LayerNorm node '{node.name}' "
363-
f"consumes CASK partition output"
364-
)
385+
inp_tensor = node.inputs[0]
386+
if inp_tensor.inputs:
387+
producer = inp_tensor.inputs[0]
388+
if _is_following_cask_partition(producer, cask_partition_nodes):
389+
conv_to_ln_nodes.append(node)
390+
logger.debug(
391+
f"Found Conv->LayerNorm pattern: LayerNorm node '{node.name}' "
392+
f"consumes CASK partition output"
393+
)
365394

366395
logger.info(f"Found {len(conv_to_ln_nodes)} Conv->LayerNorm patterns to quantize")
367396
return conv_to_ln_nodes
368397

369398

370-
def _is_input_from_cask_partition(tensor: Tensor, cask_output_tensor_names: set[str]) -> bool:
371-
"""Check if a tensor originates from a CASK partition output, traversing through copy ops."""
372-
if tensor.name in cask_output_tensor_names:
373-
return True
374-
375-
# Traverse backward through copy ops (Reshape, Transpose, etc.)
376-
if tensor.inputs:
377-
producer = tensor.inputs[0]
378-
if is_copy_op(producer.op):
379-
for inp in producer.inputs:
380-
if not isinstance(inp, Constant) and _is_input_from_cask_partition(
381-
inp, cask_output_tensor_names
382-
):
383-
return True
384-
385-
return False
386-
387-
388399
def filter_quantizable_kgen_heads(
389400
cask_fusible_partitions: list[list[Node]],
390401
kgen_partitions: list[list[Node]],
391402
quantizable_op_types: list[str],
392403
graph: Graph,
393404
) -> tuple[list[Node], list[tuple[Node, Node, str]]]:
394405
"""Returns the list of kgen head names if it follows a CASK partition."""
395-
cask_partition_nodes = set()
406+
cask_partition_nodes: set[str] = set()
396407
for partition in cask_fusible_partitions:
397-
cask_partition_nodes.update([node.name for node in partition])
408+
cask_partition_nodes.update(node.name for node in partition)
398409

399410
cask_partition_heads = [partition[0] for partition in cask_fusible_partitions]
400411

401-
def _is_following_cask_partition(node: Node):
402-
# Checking if cask fusible partition can be reached backward
403-
# ignoring the copy ops
404-
if node.name in cask_partition_nodes:
405-
return True
406-
407-
if not is_copy_op(node.op):
408-
return False
409-
410-
parent_nodes = get_parent_nodes(node)
411-
if len(parent_nodes) == 0:
412-
return False
413-
414-
return all(_is_following_cask_partition(parent) for parent in parent_nodes)
415-
416412
def _is_mha_epilogue_pattern(node: Node, graph: Graph):
417413
if head_node.op != "Add":
418414
return False
@@ -483,7 +479,10 @@ def _has_other_quantizable_consumer(
483479
# and decide which input of kgen head needs quantization
484480
for parent in head_parents:
485481
# If the head is consuming output of any quantizable op, then it is quantizable
486-
if _is_following_cask_partition(parent) or parent.op in output_quantization_candidates:
482+
if (
483+
_is_following_cask_partition(parent, cask_partition_nodes)
484+
or parent.op in output_quantization_candidates
485+
):
487486
# The mask add of MHA should not be quantized
488487
if _is_mha_epilogue_pattern(head_node, graph):
489488
no_quantize_inputs_of_head.append(

0 commit comments

Comments
 (0)