1515Tests for the TemporaryAND template.
1616"""
1717
18+ from functools import partial
19+
1820import pytest
1921
2022import pennylane as qml
2123from pennylane .ops .functions .assert_valid import _test_decomposition_rule
24+ from pennylane .templates .subroutines .arithmetic .temporary_and import _adjoint_TemporaryAND
2225
2326
2427class TestTemporaryAND :
@@ -43,7 +46,7 @@ def test_standard_validity(self):
4346 """Check the operation using the assert_valid function."""
4447
4548 op = qml .TemporaryAND (wires = [0 , "a" , 2 ], control_values = (0 , 0 ))
46- qml .ops .functions .assert_valid (op )
49+ qml .ops .functions .assert_valid (op , skip_decomp_matrix_check = True )
4750
4851 def test_correctness (self ):
4952 """Tests the correctness of the TemporaryAND operator.
@@ -110,6 +113,112 @@ def test_and_decompositions(self):
110113 for rule in qml .list_decomps (qml .TemporaryAND ):
111114 _test_decomposition_rule (qml .TemporaryAND ([0 , 1 , 2 ], control_values = (0 , 0 )), rule )
112115
116+ @pytest .mark .parametrize ("control_values" , [(0 , 0 ), (0 , 1 ), (1 , 0 ), (1 , 1 )])
117+ def test_adjoint_temporary_and_decomposition (self , control_values ):
118+ """
119+ Validate the MCM-based decomposition of Adjoint(TemporaryAND).
120+ """
121+ sys_wires = [0 , 1 , 2 ]
122+ work_wires = [3 ] # auxiliary qubit for deferred measure
123+ dev = qml .device ("default.qubit" , wires = sys_wires + work_wires )
124+
125+ @qml .qnode (dev )
126+ def circuit (a , b ):
127+ qml .BasisState (qml .math .array ([a , b , 0 ], dtype = int ), wires = sys_wires )
128+ qml .TemporaryAND (wires = sys_wires , control_values = control_values )
129+ _adjoint_TemporaryAND (wires = sys_wires )
130+ return qml .probs (wires = sys_wires )
131+
132+ for a in (0 , 1 ):
133+ for b in (0 , 1 ):
134+ probs = circuit (a , b )
135+ idx = (a << 2 ) | (b << 1 )
136+ assert qml .math .allclose (
137+ probs [idx ], 1.0
138+ ), f"Failed for a={ a } , b={ b } , cv={ control_values } "
139+
140+ @pytest .mark .usefixtures ("enable_graph_decomposition" )
141+ def test_adjoint_temporary_and_integration (self ):
142+ wires = [0 , 1 , "aux0" , 2 ]
143+ gate_set = {"X" , "T" , "Adjoint(T)" , "Hadamard" , "CX" , "CZ" , "MidMeasureMP" , "Adjoint(S)" }
144+
145+ @qml .set_shots (1 )
146+ @qml .qnode (qml .device ("default.qubit" , wires = wires ), interface = None )
147+ @partial (
148+ qml .transforms .decompose ,
149+ gate_set = gate_set ,
150+ fixed_decomps = {
151+ qml .Select : qml .templates .subroutines .select ._select_decomp_unary # pylint: disable=protected-access
152+ },
153+ )
154+ def circuit ():
155+ ops = [qml .Z (2 ) for _ in range (4 )]
156+ qml .Select (ops , control = [0 , 1 ], work_wires = ["aux0" ], partial = True )
157+ return qml .sample (wires = wires )
158+
159+ tape = qml .workflow .construct_tape (circuit )()
160+ expected_operators = [
161+ qml .X (0 ),
162+ qml .X (1 ),
163+ qml .H ("aux0" ),
164+ qml .adjoint (qml .T ("aux0" )),
165+ qml .H ("aux0" ),
166+ qml .CZ (wires = [1 , "aux0" ]),
167+ qml .H ("aux0" ),
168+ qml .T ("aux0" ),
169+ qml .H ("aux0" ),
170+ qml .CZ (wires = [0 , "aux0" ]),
171+ qml .H ("aux0" ),
172+ qml .adjoint (qml .T ("aux0" )),
173+ qml .H ("aux0" ),
174+ qml .CZ (wires = [1 , "aux0" ]),
175+ qml .H ("aux0" ),
176+ qml .T ("aux0" ),
177+ qml .H ("aux0" ),
178+ qml .adjoint (qml .S ("aux0" )),
179+ qml .X (0 ),
180+ qml .X (1 ),
181+ qml .CZ (wires = ["aux0" , 2 ]),
182+ qml .H ("aux0" ),
183+ qml .CZ (wires = [0 , "aux0" ]),
184+ qml .H ("aux0" ),
185+ qml .X ("aux0" ),
186+ qml .CZ (wires = ["aux0" , 2 ]),
187+ qml .H ("aux0" ),
188+ qml .CZ (wires = [0 , "aux0" ]),
189+ qml .H ("aux0" ),
190+ qml .H ("aux0" ),
191+ qml .CZ (wires = [1 , "aux0" ]),
192+ qml .H ("aux0" ),
193+ qml .CZ (wires = ["aux0" , 2 ]),
194+ qml .H ("aux0" ),
195+ qml .CZ (wires = [0 , "aux0" ]),
196+ qml .H ("aux0" ),
197+ qml .CZ (wires = ["aux0" , 2 ]),
198+ qml .H ("aux0" ),
199+ qml .measurements .MidMeasureMP (wires = ["aux0" ], postselect = None , reset = True ),
200+ "ConditionalCZ" ,
201+ ]
202+
203+ for op , exp_op in zip (tape .operations , expected_operators ):
204+ # manual check: each MidMeasure has a unique ID, which prevents
205+ # qml.equal from treating two MidMeasure as equal.
206+ if isinstance (op , qml .measurements .MidMeasureMP ):
207+ assert op .wires == exp_op .wires
208+ assert op .postselect == exp_op .postselect
209+ assert op .reset == exp_op .reset
210+
211+ # manual check for the conditional operator
212+ elif isinstance (op , qml .ops .op_math .condition .Conditional ):
213+ assert exp_op == "ConditionalCZ"
214+ assert isinstance (op .base , qml .CZ )
215+ assert list (op .base .wires ) == [0 , 1 ]
216+ meas = op .meas_val # same as the expr passed to qml.cond
217+ assert list (meas .wires ) == ["aux0" ]
218+
219+ else :
220+ qml .assert_equal (op , exp_op )
221+
113222 @pytest .mark .parametrize ("control_values" , [(0 , 0 ), (0 , 1 ), (1 , 0 ), (1 , 1 )])
114223 def test_compute_matrix_temporary_and (self , control_values ):
115224 """Tests that the matrix of the TemporaryAND operator is correct."""
0 commit comments