Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 78 additions & 18 deletions modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,34 +324,91 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[
return tensor_consumer_map


def _is_following_cask_partition(
    node: Node, cask_partition_nodes: set[str], max_depth: int = 10
) -> bool:
    """Check if a CASK fusible partition can be reached by traversing backward through copy ops.

    Only the data (activation) edge is followed through copy ops: ops like
    Reshape/Squeeze/Unsqueeze may take auxiliary shape/axes inputs whose
    producers are not on the activation path, and those auxiliary producers
    must not veto an otherwise valid Conv -> copy-op -> consumer chain.

    Args:
        node: The node to check.
        cask_partition_nodes: Set of node names belonging to CASK partitions.
        max_depth: Maximum recursion depth to guard against pathological graphs.

    Returns:
        True if the node belongs to or follows a CASK partition through copy ops.
    """
    if node.name in cask_partition_nodes:
        return True

    # Stop at the depth limit or as soon as a non-copy op breaks the chain.
    if max_depth <= 0 or not is_copy_op(node.op):
        return False

    # Follow only the producer of the first (data) input; auxiliary inputs
    # (e.g. a dynamic shape tensor of a Reshape) are irrelevant here.
    if not node.inputs or not node.inputs[0].inputs:
        return False

    data_parent = node.inputs[0].inputs[0]
    return _is_following_cask_partition(data_parent, cask_partition_nodes, max_depth - 1)
Comment on lines +346 to +353
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Follow only the data edge through copy ops.

Line 350 currently requires every parent of a copy op to trace back to a CASK partition. That rejects valid Conv -> Reshape/Squeeze/Unsqueeze -> LayerNormalization chains when the shape/axes input is dynamic, because those auxiliary inputs are not part of the activation path. The same false negative now affects filter_quantizable_kgen_heads() too, since it shares this helper.

💡 Proposed fix
-    parent_nodes = get_parent_nodes(node)
-    if len(parent_nodes) == 0:
+    if not node.inputs or not node.inputs[0].inputs:
         return False
 
-    return all(
-        _is_following_cask_partition(parent, cask_partition_nodes, max_depth - 1)
-        for parent in parent_nodes
-    )
+    data_parent = node.inputs[0].inputs[0]
+    return _is_following_cask_partition(data_parent, cask_partition_nodes, max_depth - 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/quantization/graph_utils.py` around lines 346 - 353, The helper
_is_following_cask_partition incorrectly requires every parent of a copy op to
lead to a CASK partition; instead, when encountering copy ops (e.g., Reshape,
Squeeze, Unsqueeze — the copy/shape-only ops used in activation path), only
follow the data input edge (the activation input, typically input index 0)
rather than all parents. Update the logic where get_parent_nodes(node) is used
in _is_following_cask_partition so that for parents whose op_type is a copy op
you recurse only on the parent node that supplies the data input (e.g., obtain
the input node at index 0 and recurse on that), while leaving non-copy parents
unchanged; this will also fix callers such as filter_quantizable_kgen_heads that
rely on this helper.



def find_conv_to_layernorm_nodes(
    graph: Graph,
    cask_fusible_partitions: list[list[Node]],
) -> list[Node]:
    """Find LayerNormalization nodes whose input comes from a CASK (Conv) partition.

    A Conv output feeding a LayerNormalization should be quantized so TRT can pick
    faster INT8 kernels. The returned LayerNormalization nodes are meant to be added
    to the quantizable-nodes list, causing Q/DQ pairs to be inserted on their input
    (i.e. on the Conv output).

    Args:
        graph: ONNX model graph.
        cask_fusible_partitions: List of CASK fusible partitions.

    Returns:
        List of LayerNormalization nodes that consume CASK partition outputs.
    """
    # Flatten all partition members into a single name set for O(1) lookups.
    partition_node_names: set[str] = {
        member.name for partition in cask_fusible_partitions for member in partition
    }

    ln_consumers: list[Node] = []
    for candidate in graph.nodes:
        if candidate.op != "LayerNormalization":
            continue

        # The activation is the first input; it may reach the CASK partition
        # through copy ops (Reshape, Transpose, etc.).
        activation = candidate.inputs[0]
        if not activation.inputs:
            continue
        producer = activation.inputs[0]
        if _is_following_cask_partition(producer, partition_node_names):
            ln_consumers.append(candidate)
            logger.debug(
                f"Found Conv->LayerNorm pattern: LayerNorm node '{candidate.name}' "
                f"consumes CASK partition output"
            )

    logger.info(f"Found {len(ln_consumers)} Conv->LayerNorm patterns to quantize")
    return ln_consumers


def filter_quantizable_kgen_heads(
cask_fusible_partitions: list[list[Node]],
kgen_partitions: list[list[Node]],
quantizable_op_types: list[str],
graph: Graph,
) -> tuple[list[Node], list[tuple[Node, Node, str]]]:
"""Returns the list of kgen head names if it follows a CASK partition."""
cask_partition_nodes = set()
cask_partition_nodes: set[str] = set()
for partition in cask_fusible_partitions:
cask_partition_nodes.update([node.name for node in partition])
cask_partition_nodes.update(node.name for node in partition)

cask_partition_heads = [partition[0] for partition in cask_fusible_partitions]

def _is_following_cask_partition(node: Node):
# Checking if cask fusible partition can be reached backward
# ignoring the copy ops
if node.name in cask_partition_nodes:
return True

if not is_copy_op(node.op):
return False

parent_nodes = get_parent_nodes(node)
if len(parent_nodes) == 0:
return False

return all(_is_following_cask_partition(parent) for parent in parent_nodes)

def _is_mha_epilogue_pattern(node: Node, graph: Graph):
if head_node.op != "Add":
return False
Expand Down Expand Up @@ -422,7 +479,10 @@ def _has_other_quantizable_consumer(
# and decide which input of kgen head needs quantization
for parent in head_parents:
# If the head is consuming output of any quantizable op, then it is quantizable
if _is_following_cask_partition(parent) or parent.op in output_quantization_candidates:
if (
_is_following_cask_partition(parent, cask_partition_nodes)
or parent.op in output_quantization_candidates
):
# The mask add of MHA should not be quantized
if _is_mha_epilogue_pattern(head_node, graph):
no_quantize_inputs_of_head.append(
Expand Down
7 changes: 6 additions & 1 deletion modelopt/onnx/quantization/int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
classify_partition_nodes,
expand_node_names_from_patterns,
filter_quantizable_kgen_heads,
find_conv_to_layernorm_nodes,
find_nodes_from_convs_to_exclude,
find_nodes_from_matmul_to_exclude,
find_nodes_to_exclude,
Expand Down Expand Up @@ -88,12 +89,16 @@ def _find_nodes_to_quantize(
quantizable_op_types,
graph,
)
# Find LayerNormalization nodes fed by Conv (CASK) partitions.
# These need Q/DQ on their input to enable faster INT8 kernels in TRT.
conv_to_ln_nodes = find_conv_to_layernorm_nodes(graph, cask_fusible_partitions)

logger.info(
f"Found {len(quantizable_partition_nodes)} quantizable partition "
f"nodes and {len(quantizable_kgen_heads)} quantizable KGEN heads"
)

quantizable_nodes = quantizable_kgen_heads + quantizable_partition_nodes
quantizable_nodes = quantizable_kgen_heads + quantizable_partition_nodes + conv_to_ln_nodes
partially_quantizable_nodes = [dst for _, dst, _ in no_quantize_inputs]
# Quantize all inputs of partially quantizable nodes by ORT
# but remove QDQ from non-quantizable inputs in the post-processing step
Expand Down
1 change: 1 addition & 0 deletions modelopt/onnx/quantization/ort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def configure_ort(
logger.debug("Registering custom QDQ operators")
QDQRegistry["BatchNormalization"] = QDQNormalization
QDQRegistry["ConvTranspose"] = QDQConvTranspose
QDQRegistry["LayerNormalization"] = QDQNormalization # Conv->LayerNorm quantization
QDQRegistry["LRN"] = QDQNormalization # Example: caffenet-12.onnx
QDQRegistry["HardSwish"] = (
QDQOperatorBase # Example: mobilenet_v3_opset17, efficientvit_b3_opset17
Expand Down
115 changes: 115 additions & 0 deletions tests/_test_utils/onnx/lib_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,118 @@ def build_conv_resize_model():
onnx.checker.check_model(model_inferred)

return model_inferred


def build_conv_layernorm_model():
    """Build a ConvNext-like model with Conv -> Transpose -> LayerNorm pattern.

    The graph is a simplified ConvNext stem block:
    Conv -> Transpose(NCHW->NHWC) -> LayerNorm -> Transpose(NHWC->NCHW) -> Conv
    """
    channels = 32

    # Single float input / output with static shapes (56x56 in, stride-4 stem -> 14x14 out).
    inputs = [
        helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 3, 56, 56)),
    ]
    outputs = [
        helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, channels, 14, 14)),
    ]

    stem_conv = helper.make_node(
        op_type="Conv",
        inputs=["input_0", "stem_conv_w", "stem_conv_b"],
        outputs=["stem_conv_out"],
        name="stem_conv",
        kernel_shape=[4, 4],
        strides=[4, 4],
    )
    # LayerNorm normalizes the last axis, so move channels last (NCHW -> NHWC).
    to_nhwc = helper.make_node(
        op_type="Transpose",
        inputs=["stem_conv_out"],
        outputs=["stem_transpose1_out"],
        name="stem_transpose1",
        perm=[0, 2, 3, 1],
    )
    layer_norm = helper.make_node(
        op_type="LayerNormalization",
        inputs=["stem_transpose1_out", "stem_ln_scale", "stem_ln_bias"],
        outputs=["stem_ln_out"],
        name="stem_ln",
        axis=-1,
        epsilon=1e-6,
    )
    # Back to NCHW for the second convolution.
    to_nchw = helper.make_node(
        op_type="Transpose",
        inputs=["stem_ln_out"],
        outputs=["stem_transpose2_out"],
        name="stem_transpose2",
        perm=[0, 3, 1, 2],
    )
    pointwise_conv = helper.make_node(
        op_type="Conv",
        inputs=["stem_transpose2_out", "conv2_w", "conv2_b"],
        outputs=["output_0"],
        name="conv2",
        kernel_shape=[1, 1],
    )
    nodes = [stem_conv, to_nhwc, layer_norm, to_nchw, pointwise_conv]

    def _weight(name, shape, values):
        # Wrap raw float data into an ONNX float initializer tensor.
        return helper.make_tensor(name, onnx.TensorProto.FLOAT, shape, values)

    initializers = [
        _weight(
            "stem_conv_w",
            [channels, 3, 4, 4],
            np.random.randn(channels * 3 * 4 * 4).astype(np.float32).tolist(),
        ),
        _weight("stem_conv_b", [channels], np.random.randn(channels).astype(np.float32).tolist()),
        _weight("stem_ln_scale", [channels], np.ones(channels).astype(np.float32).tolist()),
        _weight("stem_ln_bias", [channels], np.zeros(channels).astype(np.float32).tolist()),
        _weight(
            "conv2_w",
            [channels, channels, 1, 1],
            np.random.randn(channels * channels * 1 * 1).astype(np.float32).tolist(),
        ),
        _weight("conv2_b", [channels], np.random.randn(channels).astype(np.float32).tolist()),
    ]

    graph = helper.make_graph(nodes, "conv_layernorm", inputs, outputs, initializer=initializers)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
    model.ir_version = 8

    # Shape-infer and validate before handing the model to tests.
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred
39 changes: 39 additions & 0 deletions tests/unit/onnx/quantization/test_qdq_rules_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
build_conv_act_pool_model,
build_conv_batchnorm_sig_mul_model,
build_conv_isinf_model,
build_conv_layernorm_model,
build_convtranspose_conv_residual_model,
build_r1a_model,
build_resnet_block,
Expand Down Expand Up @@ -243,3 +244,41 @@ def test_conv_isinf_int8(tmp_path):
assert inp.dtype == supported_dtype, (
f"Node of type {node.op} has type {inp.dtype} but should have type {supported_dtype}"
)


def test_conv_layernorm_quantization(tmp_path):
    """Test that Conv -> LayerNorm pattern gets Q/DQ on the Conv output.

    Bug 5271237: ModelOpt should detect the Conv -> LayerNorm pattern and quantize
    the Conv output (the LayerNorm input) so TRT can select faster INT8 kernels.
    """
    onnx_path = os.path.join(tmp_path, "model.onnx")
    onnx.save(build_conv_layernorm_model(), onnx_path)

    # Run INT8 quantization on the saved model.
    quantize(onnx_path)

    quantized_path = onnx_path.replace(".onnx", ".quant.onnx")
    assert os.path.isfile(quantized_path)

    # Inspect Q/DQ placement in the quantized graph.
    graph = gs.import_onnx(onnx.load(quantized_path))

    # All Conv nodes must carry Q/DQ on their inputs.
    convs = [n for n in graph.nodes if n.op == "Conv"]
    assert assert_nodes_are_quantized(convs)

    # Exactly one LayerNorm is expected, with Q/DQ on its activation input.
    layernorms = [n for n in graph.nodes if n.op == "LayerNormalization"]
    assert len(layernorms) == 1, f"Expected 1 LayerNorm node, found {len(layernorms)}"

    activation_input = layernorms[0].inputs[0]
    assert activation_input.inputs, "LayerNorm activation input has no producer"
    producer = activation_input.inputs[0]
    assert producer.op == "DequantizeLinear", (
        f"LayerNorm activation input should come from DequantizeLinear, "
        f"but comes from {producer.op}. Conv->LayerNorm output quantization is missing!"
    )
Loading