Skip to content

Commit a7b7398

Browse files
authored
Merge branch 'master' into master
2 parents d998f45 + 4b740f0 commit a7b7398

File tree

3 files changed

+239
-12
lines changed

3 files changed

+239
-12
lines changed

doc/releases/changelog-dev.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030

3131
<h3>Improvements 🛠</h3>
3232

33+
* Added a new decomposition, `_decompose_2_cnots`, for the two-qubit decomposition for `QubitUnitary`.
34+
It supports the analytical decomposition a two-qubit unitary known to require exactly 2 CNOTs.
35+
[(#8666)](https://github.com/PennyLaneAI/pennylane/issues/8666)
36+
3337
* Arithmetic dunder methods (`__add__`, `__mul__`, `__rmul__`) have been added to
3438
:class:`~.transforms.core.TransformDispatcher`, :class:`~.transforms.core.TransformContainer`,
3539
and :class:`~.transforms.core.TransformProgram` to enable intuitive composition of transform
@@ -614,6 +618,7 @@ Mudit Pandey,
614618
Shuli Shu,
615619
Jay Soni,
616620
nate stemen,
621+
Theodoros Trochatos,
617622
David Wierichs,
618623
Hongsheng Zheng,
619624
Zinan Zhou

pennylane/ops/op_math/decompositions/unitary_decompositions.py

Lines changed: 211 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""This module defines decomposition functions for unitary matrices."""
1616

1717
import warnings
18+
from itertools import product
1819

1920
import numpy as np
2021
from scipy import sparse
@@ -125,8 +126,10 @@ def two_qubit_decomposition(U, wires):
125126
where :math:`A, B, C, D` are :math:`SU(2)` operations, and the rotation angles are
126127
computed based on features of the input unitary :math:`U`.
127128
128-
For the 2-CNOT case, the decomposition is currently not supported and will
129-
instead produce a 3-CNOT circuit like above.
129+
For the 2-CNOT case, the decomposition is based on the
130+
real-trace criterion of Proposition III.3 in reference (2).
131+
Whenever :math:`trace(\gamma(U))` has real coefficients (equivalently :math:`trace(\gamma(U)`) ∈ R),
132+
the decomposition uses exactly two CNOT gates.
130133
131134
For a single CNOT, we have a CNOT surrounded by one :math:`SU(2)` per wire on each
132135
side. The special case of no CNOTs simply returns a tensor product of two
@@ -207,13 +210,20 @@ def two_qubit_decomposition(U, wires):
207210
global_phase += _decompose_3_cnots(U, wires, global_phase)
208211
else:
209212
num_cnots = _compute_num_cnots(U)
210-
# Use the 3-CNOT case for num_cnots=2 as well because we do not have a reliably
211-
# correct implementation of the 2-CNOT case right now.
213+
214+
elifs = [(num_cnots == 1, _decompose_1_cnot)]
215+
216+
# The 2-CNOT decomposition relies on sorting eigenvalues, which is not supported
217+
# with abstract tracers when capture is enabled. In that case, we fall back
218+
# to the 3-CNOT decomposition.
219+
if not capture.enabled():
220+
elifs.append((num_cnots == 2, _decompose_2_cnots))
221+
212222
global_phase += ops.cond(
213223
num_cnots == 0,
214224
_decompose_0_cnots,
215225
_decompose_3_cnots,
216-
elifs=[(num_cnots == 1, _decompose_1_cnot)],
226+
elifs=elifs,
217227
)(U, wires, global_phase)
218228

219229
if _is_jax_jit(U) or not math.allclose(global_phase, 0):
@@ -377,13 +387,20 @@ def two_qubit_decomp_rule(U, wires, **__):
377387

378388
U, initial_phase = math.convert_to_su4(U, return_global_phase=True)
379389
num_cnots = _compute_num_cnots(U)
380-
# Use the 3-CNOT case for num_cnots=2 as well because we do not have a reliably
381-
# correct implementation of the 2-CNOT case right now.
390+
391+
elifs = [(num_cnots == 1, _decompose_1_cnot)]
392+
393+
# The 2-CNOT decomposition relies on sorting eigenvalues, which is not supported
394+
# with abstract tracers when capture is enabled. In that case, we fall back
395+
# to the 3-CNOT decomposition.
396+
if not capture.enabled():
397+
elifs.append((num_cnots == 2, _decompose_2_cnots))
398+
382399
additional_phase = ops.cond(
383400
num_cnots == 0,
384401
_decompose_0_cnots,
385402
_decompose_3_cnots,
386-
elifs=[(num_cnots == 1, _decompose_1_cnot)],
403+
elifs=elifs,
387404
)(U, wires, initial_phase)
388405
total_phase = initial_phase + additional_phase
389406
ops.cond(math.logical_not(math.allclose(total_phase, 0)), ops.GlobalPhase)(-total_phase)
@@ -474,15 +491,15 @@ def multi_qubit_decomp_rule(U, wires, **__):
474491

475492

476493
def _compute_num_cnots(U):
477-
r"""Compute the number of CNOTs required to implement a U in SU(4).
494+
r"""
495+
Compute the number of CNOTs required to implement a U in SU(4).
478496
This is based on the trace of
479497
480498
.. math::
481499
482500
\gamma(U) = (E^\dag U E) (E^\dag U E)^T,
483501
484502
and follows the arguments of this paper: https://arxiv.org/abs/quant-ph/0308045.
485-
486503
"""
487504

488505
U = math.dot(E_dag, math.dot(U, E))
@@ -604,6 +621,190 @@ def _decompose_1_cnot(U, wires, initial_phase):
604621
return math.cast_like(-np.pi / 4, initial_phase)
605622

606623

624+
def _get_basis_and_eigenvalues(M):
625+
r"""
626+
Helper to diagonalize M, extract diagonal eigenvalues, and sort canonically.
627+
Returns eigenvalues and basis O such that :math:`D = O^T M O` is diagonal with
628+
sorted eigenvalues.
629+
"""
630+
# pylint: disable=protected-access
631+
# Use split_eigh to get a basis (ignoring its mixed eigenvalues)
632+
_, O = _real_imag_split_eigh(M, 1.0)
633+
634+
# Compute true eigenvalues: D = O.T @ M @ O
635+
d_mat = math.dot(math.transpose(O), math.dot(M, O))
636+
eigvals = math.diag(d_mat)
637+
638+
# Canonical Sort: Real part descending, then Imaginary part descending
639+
r = np.round(math.real(eigvals), 6)
640+
i = np.round(math.imag(eigvals), 6)
641+
642+
# Sort by imag then real to ensure deterministic aligning of U and V
643+
sort_indices = np.lexsort((i, r))
644+
645+
# Reorder
646+
eigvals_sorted = eigvals[sort_indices]
647+
O_sorted = O[:, sort_indices]
648+
649+
# Enforce determinant 1 (SO(4))
650+
det = math.linalg.det(O_sorted)
651+
if math.real(det) < 0:
652+
O_sorted = math.set_index(O_sorted, (slice(None), 3), -O_sorted[:, 3])
653+
654+
return eigvals_sorted, O_sorted
655+
656+
657+
def _find_so4_decomposition(U, u_mag, O_u, candidates):
658+
r"""
659+
Performs the exhaustive search for alpha, beta, and signs
660+
to ensure Real :math:`SO(4)` correction gates.
661+
Returns the best found parameters along with the basis O_v and
662+
v_mag for the kernel V.
663+
"""
664+
CNOT10_np = np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex)
665+
CNOT10 = math.cast_like(CNOT10_np, U)
666+
best_result = None
667+
min_error = np.inf
668+
669+
# Search Parameter Permutations
670+
for alpha, beta in candidates:
671+
# Construct Kernel V
672+
rza = ops.RZ.compute_matrix(alpha)
673+
rxb = ops.RX.compute_matrix(beta)
674+
kernel_inner = math.kron(rza, rxb)
675+
V = _multidot(CNOT10, kernel_inner, CNOT10)
676+
677+
# Get basis for V
678+
v_mag = _multidot(math.cast_like(E_dag, U), V, math.cast_like(E, U))
679+
gamma_v = math.dot(v_mag, math.T(v_mag))
680+
_, O_v = _get_basis_and_eigenvalues(gamma_v)
681+
682+
# P = v^dag O_v, Q = O_u^T u
683+
P = math.dot(math.conj(math.T(v_mag)), O_v)
684+
Q = math.dot(math.T(O_u), u_mag)
685+
686+
# Sign Search to fix Gauge Freedom
687+
for signs in product([1.0, -1.0], repeat=4):
688+
# Enforce determinant +1 to stay in SO(4)
689+
if np.prod(signs) < 0:
690+
continue
691+
692+
S_diag = math.cast_like(signs, P)
693+
# Broadcasting P * S is faster than matrix mult
694+
# Metric: sum of absolute imaginary parts (should be 0 for valid decomposition)
695+
R_trial = math.dot(P * S_diag, Q)
696+
error = math.sum(math.abs(math.imag(R_trial)))
697+
698+
if error < min_error:
699+
min_error = error
700+
best_result = (alpha, beta, signs, O_v, v_mag)
701+
702+
# FALLBACK CHECK:
703+
# If the best error we found is still "large" (e.g., > 1e-5),
704+
# then we failed to find a valid Real-valued decomposition.
705+
# We return None to signal that 2-CNOT decomposition is likely impossible/unsafe.
706+
if min_error > 1e-5:
707+
return None
708+
709+
return best_result
710+
711+
712+
def _decompose_2_cnots(U, wires, initial_phase):
713+
r"""Decompose a two-qubit unitary known to require exactly 2 CNOTs.
714+
715+
The resulting circuit has the following canonical form:
716+
0: ──A──╭X──RZ(a)──╭X──C──┤
717+
1: ──B──╰●──RX(b)──╰●──D──┤
718+
719+
where A, B, C, D are single-qubit unitaries (:math:`SU(2)` gates) and a, b
720+
are rotation angles determined by the entanglement properties of the unitary.
721+
722+
This implementation is based on the work by Shende, Bullock, and Markov and
723+
this is done following the methods in https://arxiv.org/abs/quant-ph/0308045.
724+
725+
The decomposition relies on the Magic Basis, where local unitaries correspond to
726+
orthogonal matrices (:math:`SO(4)`), and the entangling power of an operator is captured by
727+
the spectrum (eigenvalues) of the symmetric invariant matrix: :math:`\gamma(U) = (E^\dagger U E) (E^\dagger U E)^T`.
728+
729+
The algorithm proceeds in four main steps:
730+
731+
1. The invariant matrix :math:`\gamma(U)` is computed. Its eigenvalues
732+
are extracted. These eigenvalues come in conjugate pairs,
733+
and their phases directly determine the rotation parameters a, and b
734+
needed for the circuit's core.
735+
736+
2. A reference operator V is built using the
737+
calculated parameters a, b. This operator represents the ideal
738+
2-CNOT circuit core: :math:`CNOT \cdot (RZ(\alpha) \otimes RX(\beta)) \cdot CNOT`.
739+
By construction, V is isospectral to U in the Magic Basis.
740+
741+
3. To transform the input U into V using only
742+
local gates, we align their eigenbases. This involves diagonalizing :math:`\gamma(U)` and
743+
:math:`\gamma(V)` and sorting their eigenvectors canonically to match corresponding subspaces.
744+
745+
4. Since eigenvectors are defined only up to a sign (parity), there are
746+
:math:`2^4 = 16` possible alignments. The algorithm exhaustively searches these combinations
747+
to find the specific parity that results in a valid, real-valued local transformation
748+
(a matrix in :math:`SO(4)`). This transformation determines the local gates A, B, C, D.
749+
750+
The final circuit is then constructed by applying these local gates around the V.
751+
"""
752+
# pylint: disable=too-many-locals
753+
# 1. Compute gamma(U)
754+
u_mag = _multidot(math.cast_like(E_dag, U), U, math.cast_like(E, U))
755+
gamma_u = math.dot(u_mag, math.T(u_mag))
756+
757+
# 2. Extract interaction parameters
758+
eig_u, O_u = _get_basis_and_eigenvalues(gamma_u)
759+
# Extract phases and sort to group conjugate pairs
760+
abs_angles = np.sort(np.abs(math.angle(eig_u)))
761+
# Pick distinct representatives
762+
theta1 = abs_angles[3]
763+
theta2 = abs_angles[1]
764+
# Map to circuit parameters
765+
a_calc = (theta1 + theta2) / 2
766+
b_calc = (theta1 - theta2) / 2
767+
candidates = [(a_calc, b_calc), (b_calc, a_calc)]
768+
769+
# 3. Perform Search (Delegated to helper)
770+
result = _find_so4_decomposition(U, u_mag, O_u, candidates)
771+
772+
# SAFETY FALLBACK to _decompose_3_cnots:
773+
# If the 2-CNOT search failed (result is None), it means U is likely
774+
# not a 2-CNOT gate (despite trace invariants) or numerical noise is too high.
775+
# We fall back to the generic 3-CNOT decomposition to guarantee correctness.
776+
if result is None:
777+
return _decompose_3_cnots(U, wires, initial_phase)
778+
779+
alpha_f, beta_f, signs_f, O_v, v_mag = result
780+
781+
# 4. Compute Local Gates L (Left) and R (Right) in SO(4)
782+
S_mat = math.diag(signs_f)
783+
L = _multidot(O_u, S_mat, math.T(O_v))
784+
R = _multidot(math.conj(math.T(v_mag)), math.T(L), u_mag)
785+
786+
# 5. Convert to local gates
787+
AB = _multidot(math.cast_like(E, U), L, math.cast_like(E_dag, U))
788+
CD = _multidot(math.cast_like(E, U), R, math.cast_like(E_dag, U))
789+
790+
A, B = math.decomposition.su2su2_to_tensor_products(AB)
791+
C, D = math.decomposition.su2su2_to_tensor_products(CD)
792+
793+
# 6. Queue Circuit
794+
ops.QubitUnitary(C, wires=wires[0])
795+
ops.QubitUnitary(D, wires=wires[1])
796+
797+
ops.CNOT(wires=[wires[1], wires[0]])
798+
ops.RZ(alpha_f, wires=wires[0])
799+
ops.RX(beta_f, wires=wires[1])
800+
ops.CNOT(wires=[wires[1], wires[0]])
801+
802+
ops.QubitUnitary(A, wires=wires[0])
803+
ops.QubitUnitary(B, wires=wires[1])
804+
805+
return math.cast_like(0.0, initial_phase)
806+
807+
607808
def _multidot(*matrices):
608809
mat = matrices[0]
609810
for m in matrices[1:]:

tests/ops/op_math/test_decompositions.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,28 @@ def test_rot_decomposition_jax(self, U, expected_gates, expected_params):
982982
class TestTwoQubitUnitaryDecomposition:
983983
"""Test that two-qubit unitary operations are correctly decomposed."""
984984

985+
@pytest.mark.unit
986+
@pytest.mark.parametrize("U", samples_2_cnots)
987+
def test_compute_num_cnots_identifies_2_cnots(self, U):
988+
"""Test that the new Shende–Bullock–Markov criterion correctly
989+
classifies 2-CNOT unitaries."""
990+
U = qml.math.convert_to_su4(np.array(U))
991+
assert _compute_num_cnots(U) == 2
992+
993+
@pytest.mark.unit
994+
@pytest.mark.parametrize("U", samples_2_cnots)
995+
def test_two_qubit_decomposition_2_cnots_gate_count(self, U):
996+
"""Test that the dispatcher selects the new 2-CNOT decomposition
997+
and that the resulting circuit actually contains exactly 2 CNOTs."""
998+
U = qml.math.convert_to_su4(np.array(U))
999+
1000+
ops = two_qubit_decomposition(U, wires=[0, 1])
1001+
1002+
# Extract only CNOTs
1003+
cnot_ops = [op for op in ops if isinstance(op, qml.CNOT)]
1004+
1005+
assert len(cnot_ops) == 2, "The 2-CNOT decomposition must emit exactly 2 CNOT gates."
1006+
9851007
@pytest.mark.parametrize("U_pair", samples_su2_su2)
9861008
def test_su2su2_to_tensor_products(self, U_pair):
9871009
"""Test SU(2) x SU(2) can be correctly factored into tensor products."""
@@ -1013,14 +1035,13 @@ def test_two_qubit_decomposition_3_cnots(self, U, wires):
10131035
@pytest.mark.parametrize("U", samples_2_cnots)
10141036
def test_two_qubit_decomposition_2_cnots(self, U, wires):
10151037
"""Test that a two-qubit matrix using 2 CNOTs isolation is correctly decomposed."""
1016-
# NOTE: Currently, we defer to the 3-CNOTs function for the 2-CNOTs case.
10171038

10181039
U = qml.math.convert_to_su4(np.array(U))
10191040

10201041
assert _compute_num_cnots(U) == 2
10211042

10221043
obtained_decomposition = two_qubit_decomposition(U, wires=wires)
1023-
assert len(obtained_decomposition) == 11 # 8 # 8 would be the count with 2-CNOT circuit
1044+
assert len(obtained_decomposition) == 8
10241045

10251046
tape = qml.tape.QuantumScript(obtained_decomposition)
10261047
obtained_matrix = qml.matrix(tape, wire_order=wires)

0 commit comments

Comments
 (0)