Skip to content

Commit 2ba6725

Browse files
mehrdad2mmudit2812
andauthored
Remove references to the xdsl pass plugin in existing tests and docs (#2277)
**Context:** As #2169 have added auto detection of xdsl passes, we don't need to specify the pass plugin argument in the tests **Description of the Change:** Remove the reference to xdsl pass plugin within the test suit. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-101160] --------- Co-authored-by: Mudit Pandey <[email protected]>
1 parent a795dbc commit 2ba6725

12 files changed

+49
-114
lines changed

frontend/catalyst/python_interface/doc/unified_compiler_cookbook.rst

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -664,10 +664,12 @@ currently rely on JAX’s API to lower to MLIR. This has the special
664664
effect of lowering to a specific dialect called StableHLO, which is used
665665
to represent all arithmetic operations present in the program.
666666

667-
Once lowered to MLIR, if the original ``qjit`` decorator specified the
668-
xDSL pass plugin, we pass control over to the xDSL layer, which applies
669-
all transforms that were requested by the user. We can request the use
670-
of the xDSL plugin like so:
667+
Once lowered to MLIR, if any xDSL registered passes are detected, we pass the control over to
668+
the xDSL layer, which automatically detects and applies all xDSL transforms that were requested
669+
by the user.
670+
671+
However, if you want to manually trigger the xDSL layer without using any xDSL registered passes,
672+
you can do so by specifying the ``pass_plugins`` parameter:
671673

672674
.. code-block:: python
673675
@@ -1003,9 +1005,7 @@ currently accessible as
10031005
qml.capture.enable()
10041006
dev = qml.device("lightning.qubit", wires=1)
10051007
1006-
@qml.qjit(
1007-
pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()]
1008-
)
1008+
@qml.qjit
10091009
@my_pass
10101010
@qml.qnode(dev)
10111011
def circuit(x):
@@ -1295,8 +1295,6 @@ will explain what is going on.
12951295

12961296
.. code-block:: python
12971297
1298-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
1299-
13001298
def test_h_to_x_pass_integration(run_filecheck_qjit):
13011299
"""Test that Hadamard gets converted into PauliX."""
13021300
# The original program simply applies a Hadamard to a circuit
@@ -1305,7 +1303,7 @@ will explain what is going on.
13051303
# `compiler_transform`. To make sure that the xDSL API works
13061304
# correctly, program capture must be enabled.
13071305
# qml.capture.enable()
1308-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath])
1306+
@qml.qjit
13091307
@h_to_x_pass
13101308
def circuit():
13111309
# CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit

frontend/test/pytest/python_interface/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_qjit(self, run_filecheck_qjit):
184184
# Test that the merge_rotations_pass works as expected when used with `qjit`
185185
dev = qml.device("lightning.qubit", wires=2)
186186
187-
@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
187+
@qml.qjit(target="mlir")
188188
@merge_rotations_pass
189189
@qml.qnode(dev)
190190
def circuit(x: float, y: float):

frontend/test/pytest/python_interface/inspection/test_draw_unified_compiler.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import jax
2424
import pennylane as qml
2525

26-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
2726
from catalyst.python_interface.inspection import draw
2827
from catalyst.python_interface.transforms import (
2928
iterative_cancel_inverses_pass,
@@ -91,9 +90,7 @@ def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected):
9190
)
9291

9392
if qjit:
94-
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
95-
transforms_circuit
96-
)
93+
transforms_circuit = qml.qjit(transforms_circuit)
9794

9895
assert draw(transforms_circuit, level=level)() == expected
9996

@@ -127,9 +124,7 @@ def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expecte
127124
)
128125

129126
if qjit:
130-
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
131-
transforms_circuit
132-
)
127+
transforms_circuit = qml.qjit(transforms_circuit)
133128

134129
assert draw(transforms_circuit, level=level)() == expected
135130

@@ -162,9 +157,7 @@ def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, ex
162157
qml.transforms.merge_rotations(transforms_circuit)
163158
)
164159
if qjit:
165-
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
166-
transforms_circuit
167-
)
160+
transforms_circuit = qml.qjit(transforms_circuit)
168161

169162
assert draw(transforms_circuit, level=level)() == expected
170163

@@ -208,9 +201,7 @@ def test_no_passes(self, transforms_circuit, level, qjit, expected):
208201
"""Test that if no passes are applied, the circuit is still visualized."""
209202

210203
if qjit:
211-
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
212-
transforms_circuit
213-
)
204+
transforms_circuit = qml.qjit(transforms_circuit)
214205

215206
assert draw(transforms_circuit, level=level)() == expected
216207

@@ -487,7 +478,7 @@ def circ(arg):
487478
def adjoint_op_not_implemented(self):
488479
"""Test that NotImplementedError is raised when AdjointOp is used."""
489480

490-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
481+
@qml.qjit
491482
@qml.qnode(qml.device("lightning.qubit", wires=1))
492483
def circuit():
493484
qml.adjoint(qml.QubitUnitary)(jax.numpy.array([[0, 1], [1, 0]]), wires=[0])
@@ -499,7 +490,7 @@ def circuit():
499490
def test_cond_not_implemented(self):
500491
"""Test that NotImplementedError is raised when cond is used."""
501492

502-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
493+
@qml.qjit
503494
@qml.qnode(qml.device("lightning.qubit", wires=2))
504495
def circuit():
505496
m0 = qml.measure(0, reset=False, postselect=0)
@@ -512,7 +503,7 @@ def circuit():
512503
def test_for_loop_not_implemented(self):
513504
"""Test that NotImplementedError is raised when for loop is used."""
514505

515-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True)
506+
@qml.qjit(autograph=True)
516507
@qml.qnode(qml.device("lightning.qubit", wires=1))
517508
def circuit():
518509
for _ in range(3):
@@ -525,7 +516,7 @@ def circuit():
525516
def test_while_loop_not_implemented(self):
526517
"""Test that NotImplementedError is raised when while loop is used."""
527518

528-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True)
519+
@qml.qjit(autograph=True)
529520
@qml.qnode(qml.device("lightning.qubit", wires=1))
530521
def circuit():
531522
i = 0

frontend/test/pytest/python_interface/inspection/test_mlir_graph.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import pennylane as qml
3131

32-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
3332
from catalyst.python_interface.inspection import generate_mlir_graph
3433
from catalyst.python_interface.transforms import (
3534
iterative_cancel_inverses_pass,
@@ -73,7 +72,7 @@ def _():
7372
return qml.state()
7473

7574
if qjit:
76-
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
75+
_ = qml.qjit(_)
7776

7877
generate_mlir_graph(_)()
7978
assert collect_files(tmp_path) == {"QNode_level_0_no_transforms.svg"}
@@ -93,7 +92,7 @@ def _():
9392
return qml.state()
9493

9594
if qjit:
96-
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
95+
_ = qml.qjit(_)
9796

9897
generate_mlir_graph(_)()
9998
assert_files(
@@ -118,7 +117,7 @@ def _(x, y, w1, w2):
118117
return qml.state()
119118

120119
if qjit:
121-
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
120+
_ = qml.qjit(_)
122121

123122
generate_mlir_graph(_)(0.1, 0.2, 0, 1)
124123
assert_files(
@@ -143,7 +142,7 @@ def _(x, y, w1, w2):
143142
return qml.state()
144143

145144
if qjit:
146-
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
145+
_ = qml.qjit(_)
147146

148147
generate_mlir_graph(_)(0.1, 0.2, 0, 1)
149148
assert_files(
@@ -169,7 +168,7 @@ def _(x, y, w1, w2):
169168
return qml.state()
170169

171170
if qjit:
172-
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
171+
_ = qml.qjit(_)
173172

174173
generate_mlir_graph(_)(0.1, 0.2, 0, 1)
175174
assert_files(

frontend/test/pytest/python_interface/test_unified_compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_integration_catalyst_xdsl_pass_with_capture(self, capsys):
301301

302302
assert capture_enabled()
303303

304-
@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
304+
@qjit
305305
@hello_world_pass
306306
@qml.qnode(qml.device("lightning.qubit", wires=2))
307307
def f(x):
@@ -319,7 +319,7 @@ def test_integration_catalyst_xdsl_pass_no_capture(self, capsys):
319319

320320
assert not capture_enabled()
321321

322-
@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
322+
@qjit
323323
@apply_pass("hello-world")
324324
@qml.qnode(qml.device("lightning.qubit", wires=2))
325325
def f(x):
@@ -338,7 +338,7 @@ def test_integration_catalyst_mixed_passes_with_capture(self, capsys):
338338

339339
assert capture_enabled()
340340

341-
@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
341+
@qjit
342342
@hello_world_pass
343343
@qml.transforms.cancel_inverses
344344
@qml.qnode(qml.device("lightning.qubit", wires=2))
@@ -359,7 +359,7 @@ def test_integration_catalyst_mixed_passes_no_capture(self, capsys):
359359

360360
assert not capture_enabled()
361361

362-
@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
362+
@qjit
363363
@apply_pass("hello-world")
364364
@catalyst_cancel_inverses
365365
@qml.qnode(qml.device("lightning.qubit", wires=2))
@@ -495,7 +495,7 @@ def print_between_passes(_, module, __, pass_level=0):
495495
print("=== Between Pass ===")
496496
print(module)
497497

498-
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
498+
@qml.qjit
499499
@iterative_cancel_inverses_pass
500500
@merge_rotations_pass
501501
@qml.qnode(qml.device("null.qubit", wires=2))

frontend/test/pytest/python_interface/transforms/mbqc/test_xdsl_outline_state_evolution.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from pennylane.ftqc import RotXZX
2323

2424
from catalyst.ftqc import mbqc_pipeline
25-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
2625
from catalyst.python_interface.transforms import (
2726
OutlineStateEvolutionPass,
2827
convert_to_mbqc_formalism_pass,
@@ -147,10 +146,7 @@ def test_outline_state_evolution_no_error(self):
147146
"""Test outline_state_evolution_pass does not raise error for circuit with classical
148147
operations only."""
149148

150-
@qml.qjit(
151-
target="mlir",
152-
pass_plugins=[getXDSLPluginAbsolutePath()],
153-
)
149+
@qml.qjit(target="mlir")
154150
@outline_state_evolution_pass
155151
def circuit(x, y):
156152
return x * y + 5
@@ -165,10 +161,7 @@ def test_outline_state_evolution_no_terminal_op_error(self):
165161
# the program is captured.
166162
dev = qml.device("null.qubit", wires=10)
167163

168-
@qml.qjit(
169-
target="mlir",
170-
pass_plugins=[getXDSLPluginAbsolutePath()],
171-
)
164+
@qml.qjit(target="mlir")
172165
@outline_state_evolution_pass
173166
@qml.qnode(dev)
174167
def circuit():
@@ -184,10 +177,7 @@ def test_outline_state_evolution_pass_only(self, run_filecheck_qjit):
184177
"""Test the outline_state_evolution_pass only."""
185178
dev = qml.device("lightning.qubit", wires=1000)
186179

187-
@qml.qjit(
188-
target="mlir",
189-
pass_plugins=[getXDSLPluginAbsolutePath()],
190-
)
180+
@qml.qjit(target="mlir")
191181
@outline_state_evolution_pass
192182
@qml.set_shots(1000)
193183
@qml.qnode(dev)
@@ -223,11 +213,7 @@ def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_f
223213
on lightning.qubit."""
224214
dev = qml.device("lightning.qubit", wires=1000)
225215

226-
@qml.qjit(
227-
target="mlir",
228-
pass_plugins=[getXDSLPluginAbsolutePath()],
229-
pipelines=mbqc_pipeline(),
230-
)
216+
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
231217
@decompose_graph_state_pass
232218
@convert_to_mbqc_formalism_pass
233219
@outline_state_evolution_pass
@@ -273,11 +259,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qji
273259
null.qubit."""
274260
dev = qml.device("null.qubit", wires=1000)
275261

276-
@qml.qjit(
277-
target="mlir",
278-
pass_plugins=[getXDSLPluginAbsolutePath()],
279-
pipelines=mbqc_pipeline(),
280-
)
262+
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
281263
@decompose_graph_state_pass
282264
@convert_to_mbqc_formalism_pass
283265
@measurements_from_samples_pass
@@ -323,11 +305,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self):
323305
transform pipeline can be executed on null.qubit."""
324306
dev = qml.device("null.qubit", wires=1000)
325307

326-
@qml.qjit(
327-
target="mlir",
328-
pass_plugins=[getXDSLPluginAbsolutePath()],
329-
pipelines=mbqc_pipeline(),
330-
)
308+
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
331309
@decompose_graph_state_pass
332310
@convert_to_mbqc_formalism_pass
333311
@measurements_from_samples_pass
@@ -367,10 +345,7 @@ def while_fn(i):
367345
i = i + 1
368346
return i
369347

370-
@qml.qjit(
371-
target="mlir",
372-
pass_plugins=[getXDSLPluginAbsolutePath()],
373-
)
348+
@qml.qjit(target="mlir")
374349
@outline_state_evolution_pass
375350
@qml.qnode(dev)
376351
def circuit():

frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_cancel_inverses.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import pennylane as qml
2323

24-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
2524
from catalyst.python_interface.transforms import (
2625
IterativeCancelInversesPass,
2726
iterative_cancel_inverses_pass,
@@ -194,7 +193,7 @@ def test_qjit(self, run_filecheck_qjit):
194193
"""Test that the IterativeCancelInversesPass works correctly with qjit."""
195194
dev = qml.device("lightning.qubit", wires=2)
196195

197-
@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
196+
@qml.qjit(target="mlir")
198197
@iterative_cancel_inverses_pass
199198
@qml.qnode(dev)
200199
def circuit():
@@ -212,7 +211,7 @@ def test_qjit_no_cancellation(self, run_filecheck_qjit):
212211
there are no operations that can be cancelled."""
213212
dev = qml.device("lightning.qubit", wires=2)
214213

215-
@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
214+
@qml.qjit(target="mlir")
216215
@iterative_cancel_inverses_pass
217216
@qml.qnode(dev)
218217
def circuit():

frontend/test/pytest/python_interface/transforms/quantum/test_xdsl_combine_global_phases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import pennylane as qml
2222

23-
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
2423
from catalyst.python_interface.transforms import (
2524
CombineGlobalPhasesPass,
2625
combine_global_phases_pass,
@@ -224,7 +223,7 @@ def test_qjit(self, run_filecheck_qjit):
224223
"""Test that the CombineGlobalPhasesPass works correctly with qjit."""
225224
dev = qml.device("lightning.qubit", wires=2)
226225

227-
@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
226+
@qml.qjit(target="mlir")
228227
@combine_global_phases_pass
229228
@qml.qnode(dev)
230229
def circuit(x: float, y: float):

0 commit comments

Comments
 (0)