Skip to content

Commit 83e9232

Browse files
jzaia18mudit2812mehrdad2mdime10
authored
Implement new mlir_specs Python function to get resource counts from MLIR passes (#2238)
**Context:** There is currently no way to view the impact to runtime costs of a given MLIR pass from the python frontend of PennyLane. **Description of the Change:** Creates a new `mlir_specs` function, in addition to a `specs_collect` function which uses xDSL to allow inspection of compilation passes written in MLIR. **Benefits:** Allows users to evaluate the impact on the IR of various MLIR passes. **Possible Drawbacks:** Not integrated with `qml.specs()`, this will be handled in a followup PR. **Related GitHub Issues:** [sc-103510] Migrated from PennyLaneAI/pennylane#8660 See also the frontend integration for this PR: PennyLaneAI/pennylane#8606 --------- Co-authored-by: Mudit Pandey <[email protected]> Co-authored-by: Mudit Pandey <[email protected]> Co-authored-by: Mehrdad Malek <[email protected]> Co-authored-by: David Ittah <[email protected]>
1 parent 3752ec6 commit 83e9232

File tree

8 files changed

+1417
-5
lines changed

8 files changed

+1417
-5
lines changed

doc/releases/changelog-dev.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@
355355

356356
<h3>Internal changes ⚙️</h3>
357357

358+
* A new `catalyst.python_interface.inspection.mlir_specs` method has been added to facilitate
359+
PennyLane's new pass-by-pass specs feature. This function returns information gathered by parsing
360+
the xDSL generated by a given QJIT object, such as gate counts, measurements, or qubit allocations.
361+
[(#2238)](https://github.com/PennyLaneAI/catalyst/pull/2238)
362+
358363
* Resource tracking now writes out at device destruction time instead of qubit deallocation
359364
time. The written resources will be the total amount of resources collected throughout the
360365
lifetime of the execution. For executions that split work between multiple functions,

frontend/catalyst/python_interface/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Unified Compiler API for integration of Catalyst with xDSL."""
1515

1616
from .compiler import Compiler
17-
from .inspection import QMLCollector
17+
from .inspection import QMLCollector, mlir_specs
1818
from .parser import QuantumParser
1919
from .pass_api import compiler_transform
2020

@@ -23,4 +23,5 @@
2323
"compiler_transform",
2424
"QuantumParser",
2525
"QMLCollector",
26+
"mlir_specs",
2627
]

frontend/catalyst/python_interface/inspection/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,14 @@
1919
from .collector import QMLCollector
2020
from .draw import draw
2121
from .mlir_graph import generate_mlir_graph
22+
from .specs import mlir_specs
23+
from .specs_collector import ResourcesResult, specs_collect
2224

23-
__all__ = ["QMLCollector", "draw", "generate_mlir_graph"]
25+
__all__ = [
26+
"QMLCollector",
27+
"draw",
28+
"generate_mlir_graph",
29+
"specs_collect",
30+
"ResourcesResult",
31+
"mlir_specs",
32+
]
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""This file contains the implementation of the `specs` function for the Unified Compiler."""
16+
17+
from __future__ import annotations
18+
19+
import warnings
20+
from typing import TYPE_CHECKING, Literal
21+
22+
from ..compiler import Compiler
23+
from .specs_collector import ResourcesResult, specs_collect
24+
from .xdsl_conversion import get_mlir_module
25+
26+
if TYPE_CHECKING:
27+
from catalyst.jit import QJIT
28+
29+
30+
class StopCompilation(Exception):
31+
"""Custom exception to stop compilation early when the desired specs level is reached."""
32+
33+
34+
def mlir_specs(
35+
qnode: QJIT, level: int | tuple[int] | list[int] | Literal["all"], *args, **kwargs
36+
) -> ResourcesResult | dict[str, ResourcesResult]:
37+
"""Compute the specs used for a circuit at the level of an MLIR pass.
38+
39+
Args:
40+
qnode (QNode): The (QJIT'd) qnode to get the specs for
41+
level (int | tuple[int] | list[int] | "all"): The MLIR pass level to get the specs for
42+
*args: Positional arguments to pass to the QNode
43+
**kwargs: Keyword arguments to pass to the QNode
44+
45+
Returns:
46+
ResourcesResult | dict[str, ResourcesResult]: The resources for the circuit at the
47+
specified level
48+
"""
49+
cache: dict[int, tuple[ResourcesResult, str]] = {}
50+
51+
if args or kwargs:
52+
warnings.warn(
53+
"The `specs` function does not yet support dynamic arguments, "
54+
"so the results may not reflect information provided by the arguments.",
55+
UserWarning,
56+
)
57+
58+
max_level = level
59+
if max_level == "all":
60+
max_level = None
61+
elif isinstance(level, (tuple, list)):
62+
max_level = max(level)
63+
elif not isinstance(level, int):
64+
raise ValueError("The `level` argument must be an int, a tuple/list of ints, or 'all'.")
65+
66+
def _specs_callback(previous_pass, module, next_pass, pass_level=0):
67+
"""Callback function for gathering circuit specs."""
68+
69+
pass_instance = previous_pass if previous_pass else next_pass
70+
result = specs_collect(module)
71+
72+
pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance
73+
cache[pass_level] = (
74+
result,
75+
pass_name if pass_level else "Before MLIR Passes",
76+
)
77+
78+
if max_level is not None and pass_level >= max_level:
79+
raise StopCompilation("Stopping compilation after reaching max specs level.")
80+
81+
mlir_module = get_mlir_module(qnode, args, kwargs)
82+
try:
83+
Compiler.run(mlir_module, callback=_specs_callback)
84+
except StopCompilation:
85+
# We use StopCompilation to interrupt the compilation once we reach
86+
# the desired level
87+
pass
88+
89+
if level == "all":
90+
return {f"{cache[lvl][1]} (MLIR-{lvl})": cache[lvl][0] for lvl in sorted(cache.keys())}
91+
92+
if isinstance(level, (tuple, list)):
93+
if any(lvl not in cache for lvl in level):
94+
missing = [str(lvl) for lvl in level if lvl not in cache]
95+
raise ValueError(
96+
f"Requested specs levels {', '.join(missing)} not found in MLIR pass list."
97+
)
98+
return {f"{cache[lvl][1]} (MLIR-{lvl})": cache[lvl][0] for lvl in level if lvl in cache}
99+
100+
# Just one level was specified
101+
if level not in cache:
102+
raise ValueError(f"Requested specs level {level} not found in MLIR pass list.")
103+
return cache[level][0]

0 commit comments

Comments
 (0)