diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 131723e61..d80ac68e2 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -1082,12 +1082,17 @@ def find_nodes_from_matmul_to_exclude( return [*set(nodes_to_exclude)] +_MIN_CHANNELS_FP8 = 16 + + def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"): """Find unsupported Conv nodes to exclude from quantization. - The input and output channels should be >= 16. The exception is for Conv layers in INT8 quantization mode, which supports it if the input or output channel % 8. - The filter size for FP8 conv kernels should be less than 32. + - For FP8 mode, Conv nodes with input or output channels <= _MIN_CHANNELS_FP8 are excluded. + Small-channel convolutions do not benefit from FP8 quantization. Args: graph: Onnx model graph. @@ -1147,6 +1152,20 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"): if quantize_mode == "fp8" and filter_size > 32: logger.debug(f"Found large filter conv for FP8: {node.name}") unsupported_conv_nodes.append(node.name) + # skip the small-channel check below; already excluded + continue + + # For FP8, exclude small-channel convolutions. These layers do not benefit from + # FP8 quantization and cause perf regressions on GPUs where the FP8 conv kernels + # are slower than FP16 CASK kernels for small channels. 
+ if quantize_mode == "fp8" and ( + output_channel <= _MIN_CHANNELS_FP8 or input_channel <= _MIN_CHANNELS_FP8 + ): + logger.debug( + f"Excluding small-channel Conv from FP8 quantization: {node.name} " + f"(IC={input_channel}, OC={output_channel}, threshold={_MIN_CHANNELS_FP8})" + ) + unsupported_conv_nodes.append(node.name) logger.info(f"Found {len(unsupported_conv_nodes)} unsupported Conv nodes for quantization") return unsupported_conv_nodes diff --git a/tests/unit/onnx/quantization/test_graph_utils.py b/tests/unit/onnx/quantization/test_graph_utils.py new file mode 100644 index 000000000..1deaa1b8d --- /dev/null +++ b/tests/unit/onnx/quantization/test_graph_utils.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import onnx_graphsurgeon as gs +import pytest + +from modelopt.onnx.quantization.graph_utils import find_nodes_from_convs_to_exclude + + +def _make_conv_graph(output_channels, input_channels, kernel_shape=(3, 3), name="Conv_0"): + """Build a minimal graph with a single Conv node.""" + spatial = [32, 32] + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, input_channels, *spatial]) + out = gs.Variable(name="output", dtype=np.float32) + + weight_shape = (output_channels, input_channels, *kernel_shape) + weight = gs.Constant(name="weight", values=np.ones(weight_shape, dtype=np.float32)) + + conv = gs.Node( + name=name, + op="Conv", + inputs=[inp, weight], + outputs=[out], + attrs={"kernel_shape": list(kernel_shape)}, + ) + + return gs.Graph(nodes=[conv], inputs=[inp], outputs=[out], opset=13) + + +@pytest.mark.parametrize( + ("oc", "ic", "expected_excluded"), + [ + (16, 64, True), + (64, 16, True), + (8, 8, True), + (16, 16, True), + (17, 64, False), + (64, 17, False), + (17, 17, False), + (32, 32, False), + (64, 64, False), + ], +) +def test_fp8_small_channel_conv_exclusion(oc, ic, expected_excluded): + """FP8 mode should exclude Conv nodes with OC or IC <= 16.""" + graph = _make_conv_graph(output_channels=oc, input_channels=ic) + excluded = find_nodes_from_convs_to_exclude(graph, quantize_mode="fp8") + if expected_excluded: + assert "Conv_0" in excluded + else: + assert "Conv_0" not in excluded + + +def test_fp8_small_channel_exclusion_does_not_affect_int8(): + """The small-channel FP8 exclusion should not apply in int8 mode.""" + # OC=8 would be excluded in FP8 (see oc=8, ic=8 case above), but not in int8. 
+    graph = _make_conv_graph(output_channels=8, input_channels=64, kernel_shape=(3, 3))
+    excluded = find_nodes_from_convs_to_exclude(graph, quantize_mode="int8")
+    assert "Conv_0" not in excluded
+
+
+@pytest.mark.parametrize(
+    ("oc", "ic"),
+    [
+        (15, 64),
+        (64, 15),
+        (1, 1),
+    ],
+)
+def test_fp8_channels_below_16_excluded_by_general_check(oc, ic):
+    """FP8 mode excludes channels strictly < 16 (both the general >=16 channel check and the FP8 small-channel check match)."""
+    graph = _make_conv_graph(output_channels=oc, input_channels=ic, kernel_shape=(3, 3))
+    excluded = find_nodes_from_convs_to_exclude(graph, quantize_mode="fp8")
+    assert "Conv_0" in excluded