Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions modelopt/onnx/graph_surgery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@
... )
>>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
>>> add_cross_kv_to_encoder(
... encoder_path="encoder_model.onnx",
... model_path="encoder_model.onnx",
... output_path="encoder_with_kv.onnx",
... hf_model_id="openai/whisper-large-v3-turbo",
... )
>>> # Standalone FP16 to BF16 conversion
>>> convert_fp16_to_bf16(
... input_path="model_fp16.onnx",
... model_path="model_fp16.onnx",
... output_path="model_bf16.onnx",
... )
>>>
Expand All @@ -69,9 +69,70 @@
from .gqa_replacement import replace_attention_with_gqa
from .utils.dtype_conversion import convert_fp16_to_bf16

# Maps CLI-style surgery names to their implementation functions.
# run_graph_surgery() dispatches through this table, so registering a new
# surgery here makes it available to callers with no further code changes.
_SURGERY_REGISTRY = {
    "replace-gqa": replace_attention_with_gqa,
    "add-cross-kv": add_cross_kv_to_encoder,
    "convert-bf16": convert_fp16_to_bf16,
    "transpose-dq": transpose_dequantize_linear_weights,
}


def get_available_surgeries() -> list[str]:
    """List the names of every registered graph surgery.

    Returns:
        The registered surgery names, in registration order.
    """
    # Iterating a dict yields its keys, so no .keys() call is needed.
    return list(_SURGERY_REGISTRY)


def run_graph_surgery(
    surgery_name: str,
    model_path: str,
    output_path: str,
    **kwargs,
):
    """Dispatch a registered graph surgery by name.

    Single unified entry point for all graph surgeries: the name is looked
    up in the registry and the matching function is invoked. Surgeries
    added to the registry become callable here automatically, without any
    changes to calling code.

    Args:
        surgery_name: Registered surgery name (e.g. 'replace-gqa',
            'transpose-dq'). Use get_available_surgeries() to see all
            available options.
        model_path: Path to the input ONNX model.
        output_path: Path to save the output ONNX model.
        **kwargs: Surgery-specific parameters, forwarded verbatim to the
            surgery function.

    Returns:
        Whatever the surgery function returns (typically ModelProto or dict).

    Raises:
        ValueError: If surgery_name is not registered.

    Example:
        >>> from modelopt.onnx.graph_surgery import run_graph_surgery, get_available_surgeries
        >>> print(get_available_surgeries())
        ['replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq']
        >>> run_graph_surgery(
        ...     "replace-gqa",
        ...     model_path="model.onnx",
        ...     output_path="model_gqa.onnx",
        ...     hf_model_id="meta-llama/Llama-2-7b-hf",
        ... )
    """
    # Single dict lookup instead of a membership test followed by indexing.
    surgery_fn = _SURGERY_REGISTRY.get(surgery_name)
    if surgery_fn is None:
        available = ", ".join(f"'{s}'" for s in _SURGERY_REGISTRY)
        raise ValueError(f"Unknown surgery: '{surgery_name}'. Available surgeries: {available}")
    return surgery_fn(model_path=model_path, output_path=output_path, **kwargs)


# Explicit public API of the graph_surgery package (alphabetical order).
__all__ = [
    "add_cross_kv_to_encoder",
    "convert_fp16_to_bf16",
    "get_available_surgeries",
    "replace_attention_with_gqa",
    "run_graph_surgery",
    "transpose_dequantize_linear_weights",
]
4 changes: 2 additions & 2 deletions modelopt/onnx/graph_surgery/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def main():
from .encoder_cross_kv import add_cross_kv_to_encoder

add_cross_kv_to_encoder(
encoder_path=args.input,
model_path=args.input,
output_path=args.output,
hf_model_id=args.model_id,
hidden_state_output_name=args.hidden_state_name,
Expand All @@ -288,7 +288,7 @@ def main():
from .utils.dtype_conversion import convert_fp16_to_bf16

convert_fp16_to_bf16(
input_path=args.input,
model_path=args.input,
output_path=args.output,
external_data=not args.no_external_data,
verbose=not args.quiet,
Expand Down
10 changes: 5 additions & 5 deletions modelopt/onnx/graph_surgery/encoder_cross_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _add_cross_kv_outputs(


def add_cross_kv_to_encoder(
encoder_path: str,
model_path: str,
output_path: str,
hf_model_id: str,
hidden_state_output_name: str = "last_hidden_state",
Expand Down Expand Up @@ -349,7 +349,7 @@ def add_cross_kv_to_encoder(
6. Generates genai_config.json and audio_processor_config.json (optional)

Args:
encoder_path: Path to encoder ONNX model.
model_path: Path to encoder ONNX model.
output_path: Path to save modified encoder.
hf_model_id: HuggingFace model ID for loading cross-attention weights.
hidden_state_output_name: Name of encoder hidden state output.
Expand All @@ -369,7 +369,7 @@ def add_cross_kv_to_encoder(
Example:
>>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
>>> model = add_cross_kv_to_encoder(
... encoder_path="encoder_model.onnx",
... model_path="encoder_model.onnx",
... output_path="encoder_model_with_kv.onnx",
... hf_model_id="openai/whisper-large-v3-turbo",
... )
Expand All @@ -380,9 +380,9 @@ def add_cross_kv_to_encoder(
)

if verbose:
logger.info(f"Loading encoder model from: {encoder_path}")
logger.info(f"Loading encoder model from: {model_path}")

encoder_model = onnx.load(encoder_path, load_external_data=True)
encoder_model = onnx.load(model_path, load_external_data=True)

# Detect model dtype
onnx_dtype, np_dtype = detect_model_dtype(encoder_model)
Expand Down
10 changes: 5 additions & 5 deletions modelopt/onnx/graph_surgery/utils/dtype_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _convert_constant_node_to_bf16(node: onnx.NodeProto) -> bool:


def convert_fp16_to_bf16(
input_path: str,
model_path: str,
output_path: str,
external_data: bool = True,
verbose: bool = True,
Expand All @@ -147,7 +147,7 @@ def convert_fp16_to_bf16(
4. All Cast nodes that target FP16 to target BF16

Args:
input_path: Path to input FP16 ONNX model.
model_path: Path to input FP16 ONNX model.
output_path: Path to output BF16 ONNX model.
external_data: Whether to save weights as external data.
verbose: Whether to print progress messages.
Expand All @@ -157,16 +157,16 @@ def convert_fp16_to_bf16(

Example:
>>> stats = convert_fp16_to_bf16(
... input_path="model_fp16.onnx",
... model_path="model_fp16.onnx",
... output_path="model_bf16.onnx",
... )
>>> logger.info(f"Converted {stats['initializers_converted']} initializers")
"""
if verbose:
logger.info(f"Loading model from: {input_path}")
logger.info(f"Loading model from: {model_path}")

# Load model with external data
model = onnx.load(input_path, load_external_data=True)
model = onnx.load(model_path, load_external_data=True)
graph = model.graph

# Statistics
Expand Down
Loading