diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 131723e61..4c026b625 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -324,6 +324,78 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[ return tensor_consumer_map +def _is_following_cask_partition( + node: Node, cask_partition_nodes: set[str], max_depth: int = 10 +) -> bool: + """Check if a CASK fusible partition can be reached by traversing backward through copy ops. + + Args: + node: The node to check. + cask_partition_nodes: Set of node names belonging to CASK partitions. + max_depth: Maximum recursion depth to guard against pathological graphs. + + Returns: + True if the node belongs to or follows a CASK partition through copy ops. + """ + if node.name in cask_partition_nodes: + return True + + if max_depth <= 0 or not is_copy_op(node.op): + return False + + parent_nodes = get_parent_nodes(node) + if len(parent_nodes) == 0: + return False + + return all( + _is_following_cask_partition(parent, cask_partition_nodes, max_depth - 1) + for parent in parent_nodes + ) + + +def find_conv_to_layernorm_nodes( + graph: Graph, + cask_fusible_partitions: list[list[Node]], +) -> list[Node]: + """Find LayerNormalization nodes whose input comes from a CASK (Conv) partition. + + When a Conv's output feeds into a LayerNormalization, the Conv output should be + quantized to enable faster INT8 kernels in TRT. This function detects such patterns + and returns the LayerNormalization nodes that should be added to the quantizable + nodes list so that Q/DQ pairs are inserted on their input (i.e. the Conv output). + + Args: + graph: ONNX model graph. + cask_fusible_partitions: List of CASK fusible partitions. + + Returns: + List of LayerNormalization nodes that consume CASK partition outputs. + """ + cask_partition_nodes: set[str] = set() + for partition in cask_fusible_partitions: + cask_partition_nodes.update(node.name for node in partition) + + conv_to_ln_nodes = [] + for node in graph.nodes: + if node.op != "LayerNormalization": + continue + + # Check if the first input (activation) comes from a CASK partition + # possibly through copy ops (Reshape, Transpose, etc.) + inp_tensor = node.inputs[0] + if inp_tensor.inputs: + producer = inp_tensor.inputs[0] + if _is_following_cask_partition(producer, cask_partition_nodes): + conv_to_ln_nodes.append(node) + logger.debug( + f"Found Conv->LayerNorm pattern: LayerNorm node '{node.name}' " + f"consumes CASK partition output" + ) + + logger.info(f"Found {len(conv_to_ln_nodes)} Conv->LayerNorm patterns to quantize") + return conv_to_ln_nodes + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], @@ -331,27 +403,12 @@ def filter_quantizable_kgen_heads( graph: Graph, ) -> tuple[list[Node], list[tuple[Node, Node, str]]]: """Returns the list of kgen head names if it follows a CASK partition.""" - cask_partition_nodes = set() + cask_partition_nodes: set[str] = set() for partition in cask_fusible_partitions: - cask_partition_nodes.update([node.name for node in partition]) + cask_partition_nodes.update(node.name for node in partition) cask_partition_heads = [partition[0] for partition in cask_fusible_partitions] - def _is_following_cask_partition(node: Node): - # Checking if cask fusible partition can be reached backward - # ignoring the copy ops - if node.name in cask_partition_nodes: - return True - - if not is_copy_op(node.op): - return False - - parent_nodes = get_parent_nodes(node) - if len(parent_nodes) == 0: - return False - - return all(_is_following_cask_partition(parent) for parent in parent_nodes) - def _is_mha_epilogue_pattern(node: Node, graph: Graph): if head_node.op != "Add": return False @@ -422,7 +479,10 @@ def _has_other_quantizable_consumer( # and decide which input of kgen head needs quantization for parent in head_parents: # If the head is consuming output of any quantizable op, then it is quantizable - if _is_following_cask_partition(parent) or parent.op in output_quantization_candidates: + if ( + _is_following_cask_partition(parent, cask_partition_nodes) + or parent.op in output_quantization_candidates + ): # The mask add of MHA should not be quantized if _is_mha_epilogue_pattern(head_node, graph): no_quantize_inputs_of_head.append( diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index ad2ca9558..5b1ad5efe 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -35,6 +35,7 @@ classify_partition_nodes, expand_node_names_from_patterns, filter_quantizable_kgen_heads, + find_conv_to_layernorm_nodes, find_nodes_from_convs_to_exclude, find_nodes_from_matmul_to_exclude, find_nodes_to_exclude, @@ -88,12 +89,16 @@ def _find_nodes_to_quantize( quantizable_op_types, graph, ) + # Find LayerNormalization nodes fed by Conv (CASK) partitions. + # These need Q/DQ on their input to enable faster INT8 kernels in TRT. + conv_to_ln_nodes = find_conv_to_layernorm_nodes(graph, cask_fusible_partitions) + logger.info( f"Found {len(quantizable_partition_nodes)} quantizable partition " f"nodes and {len(quantizable_kgen_heads)} quantizable KGEN heads" ) - quantizable_nodes = quantizable_kgen_heads + quantizable_partition_nodes + quantizable_nodes = quantizable_kgen_heads + quantizable_partition_nodes + conv_to_ln_nodes partially_quantizable_nodes = [dst for _, dst, _ in no_quantize_inputs] # Quantize all inputs of partially quantizable nodes by ORT # but remove QDQ from non-quantizable inputs in the post-processing step diff --git a/modelopt/onnx/quantization/ort_utils.py b/modelopt/onnx/quantization/ort_utils.py index 173fbb06d..7d90f613c 100755 --- a/modelopt/onnx/quantization/ort_utils.py +++ b/modelopt/onnx/quantization/ort_utils.py @@ -280,6 +280,7 @@ def configure_ort( logger.debug("Registering custom QDQ operators") QDQRegistry["BatchNormalization"] = QDQNormalization QDQRegistry["ConvTranspose"] = QDQConvTranspose + QDQRegistry["LayerNormalization"] = QDQNormalization # Conv->LayerNorm quantization QDQRegistry["LRN"] = QDQNormalization # Example: caffenet-12.onnx QDQRegistry["HardSwish"] = ( QDQOperatorBase # Example: mobilenet_v3_opset17, efficientvit_b3_opset17 diff --git a/tests/_test_utils/onnx/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py index ff97b5142..97bc22b71 100644 --- a/tests/_test_utils/onnx/lib_test_models.py +++ b/tests/_test_utils/onnx/lib_test_models.py @@ -1009,3 +1009,118 @@ def build_conv_resize_model(): onnx.checker.check_model(model_inferred) return model_inferred + + +def build_conv_layernorm_model(): + """Build a ConvNext-like model with Conv -> Transpose -> LayerNorm pattern. + + This creates a simplified ConvNext block: + Conv -> Transpose(NCHW->NHWC) -> LayerNorm -> Transpose(NHWC->NCHW) -> Conv + """ + channels = 32 + input_names = ["input_0"] + output_names = ["output_0"] + input_shapes = [(1, 3, 56, 56)] + output_shapes = [(1, channels, 14, 14)] + + inputs = [ + helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) + for input_name, input_shape in zip(input_names, input_shapes) + ] + outputs = [ + helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) + for output_name, output_shape in zip(output_names, output_shapes) + ] + + nodes = [ + # Stem Conv: 3 -> channels with stride 4 + helper.make_node( + op_type="Conv", + inputs=["input_0", "stem_conv_w", "stem_conv_b"], + outputs=["stem_conv_out"], + name="stem_conv", + kernel_shape=[4, 4], + strides=[4, 4], + ), + # Transpose NCHW -> NHWC for LayerNorm + helper.make_node( + op_type="Transpose", + inputs=["stem_conv_out"], + outputs=["stem_transpose1_out"], + name="stem_transpose1", + perm=[0, 2, 3, 1], + ), + # LayerNorm over last axis (channels) + helper.make_node( + op_type="LayerNormalization", + inputs=["stem_transpose1_out", "stem_ln_scale", "stem_ln_bias"], + outputs=["stem_ln_out"], + name="stem_ln", + axis=-1, + epsilon=1e-6, + ), + # Transpose NHWC -> NCHW + helper.make_node( + op_type="Transpose", + inputs=["stem_ln_out"], + outputs=["stem_transpose2_out"], + name="stem_transpose2", + perm=[0, 3, 1, 2], + ), + # Second Conv to produce output + helper.make_node( + op_type="Conv", + inputs=["stem_transpose2_out", "conv2_w", "conv2_b"], + outputs=["output_0"], + name="conv2", + kernel_shape=[1, 1], + ), + ] + + initializers = [ + helper.make_tensor( + "stem_conv_w", + onnx.TensorProto.FLOAT, + [channels, 3, 4, 4], + np.random.randn(channels * 3 * 4 * 4).astype(np.float32).tolist(), + ), + helper.make_tensor( + "stem_conv_b", + onnx.TensorProto.FLOAT, + [channels], + np.random.randn(channels).astype(np.float32).tolist(), + ), + helper.make_tensor( + "stem_ln_scale", + onnx.TensorProto.FLOAT, + [channels], + np.ones(channels).astype(np.float32).tolist(), + ), + helper.make_tensor( + "stem_ln_bias", + onnx.TensorProto.FLOAT, + [channels], + np.zeros(channels).astype(np.float32).tolist(), + ), + helper.make_tensor( + "conv2_w", + onnx.TensorProto.FLOAT, + [channels, channels, 1, 1], + np.random.randn(channels * channels * 1 * 1).astype(np.float32).tolist(), + ), + helper.make_tensor( + "conv2_b", + onnx.TensorProto.FLOAT, + [channels], + np.random.randn(channels).astype(np.float32).tolist(), + ), + ] + + graph = helper.make_graph(nodes, "conv_layernorm", inputs, outputs, initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 8 + + model_inferred = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model_inferred) + + return model_inferred diff --git a/tests/unit/onnx/quantization/test_qdq_rules_int8.py b/tests/unit/onnx/quantization/test_qdq_rules_int8.py index 43d6e4a4e..5c4648c70 100644 --- a/tests/unit/onnx/quantization/test_qdq_rules_int8.py +++ b/tests/unit/onnx/quantization/test_qdq_rules_int8.py @@ -23,6 +23,7 @@ build_conv_act_pool_model, build_conv_batchnorm_sig_mul_model, build_conv_isinf_model, + build_conv_layernorm_model, build_convtranspose_conv_residual_model, build_r1a_model, build_resnet_block, @@ -243,3 +244,41 @@ def test_conv_isinf_int8(tmp_path): assert inp.dtype == supported_dtype, ( f"Node of type {node.op} has type {inp.dtype} but should have type {supported_dtype}" ) + + +def test_conv_layernorm_quantization(tmp_path): + """Test that Conv -> LayerNorm pattern gets Q/DQ on the Conv output. + + Bug 5271237: ModelOpt should detect Conv -> LayerNorm pattern and quantize + the Conv output (LayerNorm input) to enable faster INT8 kernels in TRT. + """ + model = build_conv_layernorm_model() + onnx_path = os.path.join(tmp_path, "model.onnx") + onnx.save(model, onnx_path) + + # Quantize the input model + quantize(onnx_path) + + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + assert os.path.isfile(output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(onnx.load(output_onnx_path)) + + # Check that Conv nodes are quantized (inputs have Q/DQ) + conv_nodes = [n for n in graph.nodes if n.op == "Conv"] + assert assert_nodes_are_quantized(conv_nodes) + + # Check that LayerNormalization has Q/DQ on its activation input + ln_nodes = [n for n in graph.nodes if n.op == "LayerNormalization"] + assert len(ln_nodes) == 1, f"Expected 1 LayerNorm node, found {len(ln_nodes)}" + + ln_node = ln_nodes[0] + # The activation input (input[0]) should come from a DequantizeLinear node + activation_input = ln_node.inputs[0] + assert activation_input.inputs, "LayerNorm activation input has no producer" + producer = activation_input.inputs[0] + assert producer.op == "DequantizeLinear", ( + f"LayerNorm activation input should come from DequantizeLinear, " + f"but comes from {producer.op}. Conv->LayerNorm output quantization is missing!" + )