NOTE(review): this patch was recovered from a whitespace-collapsed copy. The leading
indentation of context lines is reconstructed by convention and may not match the target
files byte-for-byte; if a hunk is rejected, retry with `git apply --ignore-whitespace`.
The added lines in graph_surgery/__init__.py were revised during review, so the
post-image blob id on that file's "index" line is informational only.

diff --git a/modelopt/onnx/graph_surgery/__init__.py b/modelopt/onnx/graph_surgery/__init__.py
index 06ac87c0b..1b3e75144 100644
--- a/modelopt/onnx/graph_surgery/__init__.py
+++ b/modelopt/onnx/graph_surgery/__init__.py
@@ -47,13 +47,13 @@
     ... )
     >>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
     >>> add_cross_kv_to_encoder(
-    ...     encoder_path="encoder_model.onnx",
+    ...     model_path="encoder_model.onnx",
     ...     output_path="encoder_with_kv.onnx",
     ...     hf_model_id="openai/whisper-large-v3-turbo",
     ... )
     >>> # Standalone FP16 to BF16 conversion
     >>> convert_fp16_to_bf16(
-    ...     input_path="model_fp16.onnx",
+    ...     model_path="model_fp16.onnx",
     ...     output_path="model_bf16.onnx",
     ... )
     >>>
@@ -69,9 +69,76 @@
 from .gqa_replacement import replace_attention_with_gqa
 from .utils.dtype_conversion import convert_fp16_to_bf16
 
+# Maps each public surgery name to the function that implements it.
+# run_graph_surgery() calls every entry with `model_path=` and `output_path=` keywords.
+_SURGERY_REGISTRY = {
+    "replace-gqa": replace_attention_with_gqa,
+    "add-cross-kv": add_cross_kv_to_encoder,
+    "convert-bf16": convert_fp16_to_bf16,
+    "transpose-dq": transpose_dequantize_linear_weights,
+}
+
+
+def get_available_surgeries() -> list[str]:
+    """Return a list of all registered graph surgery names."""
+    return list(_SURGERY_REGISTRY)
+
+
+def run_graph_surgery(
+    surgery_name: str,
+    model_path: str,
+    output_path: str,
+    **kwargs,
+):
+    """Run a graph surgery by name.
+
+    This is the unified entry point for all graph surgeries. It dispatches
+    to the appropriate surgery function based on the surgery name.
+
+    When new surgeries are added to the registry, they are automatically
+    available through this function without any changes to calling code.
+
+    Args:
+        surgery_name: Name of the surgery to run (e.g. 'replace-gqa', 'transpose-dq').
+            Use get_available_surgeries() to see all available options.
+        model_path: Path to the input ONNX model.
+        output_path: Path to save the output ONNX model.
+        **kwargs: Surgery-specific parameters. Passed directly to the surgery function.
+
+    Returns:
+        The return value of the surgery function (typically ModelProto or dict).
+
+    Raises:
+        ValueError: If surgery_name is not registered.
+
+    Example:
+        >>> from modelopt.onnx.graph_surgery import run_graph_surgery, get_available_surgeries
+        >>> print(get_available_surgeries())
+        ['replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq']
+        >>> run_graph_surgery(
+        ...     "replace-gqa",
+        ...     model_path="model.onnx",
+        ...     output_path="model_gqa.onnx",
+        ...     hf_model_id="meta-llama/Llama-2-7b-hf",
+        ... )
+    """
+    try:
+        func = _SURGERY_REGISTRY[surgery_name]
+    except KeyError:
+        # Report the valid names; `from None` hides the internal KeyError.
+        available = ", ".join(f"'{s}'" for s in _SURGERY_REGISTRY)
+        raise ValueError(
+            f"Unknown surgery: '{surgery_name}'. Available surgeries: {available}"
+        ) from None
+
+    return func(model_path=model_path, output_path=output_path, **kwargs)
+
+
 __all__ = [
     "add_cross_kv_to_encoder",
     "convert_fp16_to_bf16",
+    "get_available_surgeries",
     "replace_attention_with_gqa",
+    "run_graph_surgery",
     "transpose_dequantize_linear_weights",
 ]
diff --git a/modelopt/onnx/graph_surgery/__main__.py b/modelopt/onnx/graph_surgery/__main__.py
index 573f42a86..98586e142 100644
--- a/modelopt/onnx/graph_surgery/__main__.py
+++ b/modelopt/onnx/graph_surgery/__main__.py
@@ -271,7 +271,7 @@ def main():
         from .encoder_cross_kv import add_cross_kv_to_encoder
 
         add_cross_kv_to_encoder(
-            encoder_path=args.input,
+            model_path=args.input,
             output_path=args.output,
             hf_model_id=args.model_id,
             hidden_state_output_name=args.hidden_state_name,
@@ -288,7 +288,7 @@ def main():
         from .utils.dtype_conversion import convert_fp16_to_bf16
 
         convert_fp16_to_bf16(
-            input_path=args.input,
+            model_path=args.input,
             output_path=args.output,
             external_data=not args.no_external_data,
             verbose=not args.quiet,
diff --git a/modelopt/onnx/graph_surgery/encoder_cross_kv.py b/modelopt/onnx/graph_surgery/encoder_cross_kv.py
index 32be99185..12b0a10dd 100644
--- a/modelopt/onnx/graph_surgery/encoder_cross_kv.py
+++ b/modelopt/onnx/graph_surgery/encoder_cross_kv.py
@@ -320,7 +320,7 @@ def _add_cross_kv_outputs(
 
 
 def add_cross_kv_to_encoder(
-    encoder_path: str,
+    model_path: str,
     output_path: str,
     hf_model_id: str,
     hidden_state_output_name: str = "last_hidden_state",
@@ -349,7 +349,7 @@ def add_cross_kv_to_encoder(
     6. Generates genai_config.json and audio_processor_config.json (optional)
 
     Args:
-        encoder_path: Path to encoder ONNX model.
+        model_path: Path to encoder ONNX model.
         output_path: Path to save modified encoder.
         hf_model_id: HuggingFace model ID for loading cross-attention weights.
         hidden_state_output_name: Name of encoder hidden state output.
@@ -369,7 +369,7 @@ def add_cross_kv_to_encoder(
     Example:
         >>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
         >>> model = add_cross_kv_to_encoder(
-        ...     encoder_path="encoder_model.onnx",
+        ...     model_path="encoder_model.onnx",
         ...     output_path="encoder_model_with_kv.onnx",
         ...     hf_model_id="openai/whisper-large-v3-turbo",
         ... )
@@ -380,9 +380,9 @@ def add_cross_kv_to_encoder(
     )
 
     if verbose:
-        logger.info(f"Loading encoder model from: {encoder_path}")
+        logger.info(f"Loading encoder model from: {model_path}")
 
-    encoder_model = onnx.load(encoder_path, load_external_data=True)
+    encoder_model = onnx.load(model_path, load_external_data=True)
 
     # Detect model dtype
     onnx_dtype, np_dtype = detect_model_dtype(encoder_model)
diff --git a/modelopt/onnx/graph_surgery/utils/dtype_conversion.py b/modelopt/onnx/graph_surgery/utils/dtype_conversion.py
index 920678c1c..148caa6a4 100644
--- a/modelopt/onnx/graph_surgery/utils/dtype_conversion.py
+++ b/modelopt/onnx/graph_surgery/utils/dtype_conversion.py
@@ -133,7 +133,7 @@ def _convert_constant_node_to_bf16(node: onnx.NodeProto) -> bool:
 
 
 def convert_fp16_to_bf16(
-    input_path: str,
+    model_path: str,
     output_path: str,
     external_data: bool = True,
     verbose: bool = True,
@@ -147,7 +147,7 @@ def convert_fp16_to_bf16(
     4. All Cast nodes that target FP16 to target BF16
 
     Args:
-        input_path: Path to input FP16 ONNX model.
+        model_path: Path to input FP16 ONNX model.
         output_path: Path to output BF16 ONNX model.
         external_data: Whether to save weights as external data.
         verbose: Whether to print progress messages.
@@ -157,16 +157,16 @@ def convert_fp16_to_bf16(
 
     Example:
         >>> stats = convert_fp16_to_bf16(
-        ...     input_path="model_fp16.onnx",
+        ...     model_path="model_fp16.onnx",
         ...     output_path="model_bf16.onnx",
         ... )
         >>> logger.info(f"Converted {stats['initializers_converted']} initializers")
     """
     if verbose:
-        logger.info(f"Loading model from: {input_path}")
+        logger.info(f"Loading model from: {model_path}")
 
     # Load model with external data
-    model = onnx.load(input_path, load_external_data=True)
+    model = onnx.load(model_path, load_external_data=True)
     graph = model.graph
 
     # Statistics