|
15 | 15 | """This module defines decomposition functions for unitary matrices.""" |
16 | 16 |
|
17 | 17 | import warnings |
| 18 | +from itertools import product |
18 | 19 |
|
19 | 20 | import numpy as np |
20 | 21 | from scipy import sparse |
@@ -125,8 +126,10 @@ def two_qubit_decomposition(U, wires): |
125 | 126 | where :math:`A, B, C, D` are :math:`SU(2)` operations, and the rotation angles are |
126 | 127 | computed based on features of the input unitary :math:`U`. |
127 | 128 |
|
128 | | - For the 2-CNOT case, the decomposition is currently not supported and will |
129 | | - instead produce a 3-CNOT circuit like above. |
| 129 | + For the 2-CNOT case, the decomposition is based on the |
| 130 | + real-trace criterion of Proposition III.3 in reference (2). |
| 131 | + Whenever :math:`trace(\gamma(U))` has real coefficients (equivalently :math:`trace(\gamma(U)`) ∈ R), |
| 132 | + the decomposition uses exactly two CNOT gates. |
130 | 133 |
|
131 | 134 | For a single CNOT, we have a CNOT surrounded by one :math:`SU(2)` per wire on each |
132 | 135 | side. The special case of no CNOTs simply returns a tensor product of two |
@@ -207,13 +210,20 @@ def two_qubit_decomposition(U, wires): |
207 | 210 | global_phase += _decompose_3_cnots(U, wires, global_phase) |
208 | 211 | else: |
209 | 212 | num_cnots = _compute_num_cnots(U) |
210 | | - # Use the 3-CNOT case for num_cnots=2 as well because we do not have a reliably |
211 | | - # correct implementation of the 2-CNOT case right now. |
| 213 | + |
| 214 | + elifs = [(num_cnots == 1, _decompose_1_cnot)] |
| 215 | + |
| 216 | + # The 2-CNOT decomposition relies on sorting eigenvalues, which is not supported |
| 217 | + # with abstract tracers when capture is enabled. In that case, we fall back |
| 218 | + # to the 3-CNOT decomposition. |
| 219 | + if not capture.enabled(): |
| 220 | + elifs.append((num_cnots == 2, _decompose_2_cnots)) |
| 221 | + |
212 | 222 | global_phase += ops.cond( |
213 | 223 | num_cnots == 0, |
214 | 224 | _decompose_0_cnots, |
215 | 225 | _decompose_3_cnots, |
216 | | - elifs=[(num_cnots == 1, _decompose_1_cnot)], |
| 226 | + elifs=elifs, |
217 | 227 | )(U, wires, global_phase) |
218 | 228 |
|
219 | 229 | if _is_jax_jit(U) or not math.allclose(global_phase, 0): |
@@ -377,13 +387,20 @@ def two_qubit_decomp_rule(U, wires, **__): |
377 | 387 |
|
378 | 388 | U, initial_phase = math.convert_to_su4(U, return_global_phase=True) |
379 | 389 | num_cnots = _compute_num_cnots(U) |
380 | | - # Use the 3-CNOT case for num_cnots=2 as well because we do not have a reliably |
381 | | - # correct implementation of the 2-CNOT case right now. |
| 390 | + |
| 391 | + elifs = [(num_cnots == 1, _decompose_1_cnot)] |
| 392 | + |
| 393 | + # The 2-CNOT decomposition relies on sorting eigenvalues, which is not supported |
| 394 | + # with abstract tracers when capture is enabled. In that case, we fall back |
| 395 | + # to the 3-CNOT decomposition. |
| 396 | + if not capture.enabled(): |
| 397 | + elifs.append((num_cnots == 2, _decompose_2_cnots)) |
| 398 | + |
382 | 399 | additional_phase = ops.cond( |
383 | 400 | num_cnots == 0, |
384 | 401 | _decompose_0_cnots, |
385 | 402 | _decompose_3_cnots, |
386 | | - elifs=[(num_cnots == 1, _decompose_1_cnot)], |
| 403 | + elifs=elifs, |
387 | 404 | )(U, wires, initial_phase) |
388 | 405 | total_phase = initial_phase + additional_phase |
389 | 406 | ops.cond(math.logical_not(math.allclose(total_phase, 0)), ops.GlobalPhase)(-total_phase) |
@@ -474,15 +491,15 @@ def multi_qubit_decomp_rule(U, wires, **__): |
474 | 491 |
|
475 | 492 |
|
476 | 493 | def _compute_num_cnots(U): |
477 | | - r"""Compute the number of CNOTs required to implement a U in SU(4). |
| 494 | + r""" |
| 495 | + Compute the number of CNOTs required to implement a U in SU(4). |
478 | 496 | This is based on the trace of |
479 | 497 |
|
480 | 498 | .. math:: |
481 | 499 |
|
482 | 500 | \gamma(U) = (E^\dag U E) (E^\dag U E)^T, |
483 | 501 |
|
484 | 502 | and follows the arguments of this paper: https://arxiv.org/abs/quant-ph/0308045. |
485 | | -
|
486 | 503 | """ |
487 | 504 |
|
488 | 505 | U = math.dot(E_dag, math.dot(U, E)) |
@@ -604,6 +621,190 @@ def _decompose_1_cnot(U, wires, initial_phase): |
604 | 621 | return math.cast_like(-np.pi / 4, initial_phase) |
605 | 622 |
|
606 | 623 |
|
| 624 | +def _get_basis_and_eigenvalues(M): |
| 625 | + r""" |
| 626 | + Helper to diagonalize M, extract diagonal eigenvalues, and sort canonically. |
| 627 | + Returns eigenvalues and basis O such that :math:`D = O^T M O` is diagonal with |
| 628 | + sorted eigenvalues. |
| 629 | + """ |
| 630 | + # pylint: disable=protected-access |
| 631 | + # Use split_eigh to get a basis (ignoring its mixed eigenvalues) |
| 632 | + _, O = _real_imag_split_eigh(M, 1.0) |
| 633 | + |
| 634 | + # Compute true eigenvalues: D = O.T @ M @ O |
| 635 | + d_mat = math.dot(math.transpose(O), math.dot(M, O)) |
| 636 | + eigvals = math.diag(d_mat) |
| 637 | + |
| 638 | + # Canonical Sort: Real part descending, then Imaginary part descending |
| 639 | + r = np.round(math.real(eigvals), 6) |
| 640 | + i = np.round(math.imag(eigvals), 6) |
| 641 | + |
| 642 | + # Sort by imag then real to ensure deterministic aligning of U and V |
| 643 | + sort_indices = np.lexsort((i, r)) |
| 644 | + |
| 645 | + # Reorder |
| 646 | + eigvals_sorted = eigvals[sort_indices] |
| 647 | + O_sorted = O[:, sort_indices] |
| 648 | + |
| 649 | + # Enforce determinant 1 (SO(4)) |
| 650 | + det = math.linalg.det(O_sorted) |
| 651 | + if math.real(det) < 0: |
| 652 | + O_sorted = math.set_index(O_sorted, (slice(None), 3), -O_sorted[:, 3]) |
| 653 | + |
| 654 | + return eigvals_sorted, O_sorted |
| 655 | + |
| 656 | + |
| 657 | +def _find_so4_decomposition(U, u_mag, O_u, candidates): |
| 658 | + r""" |
| 659 | + Performs the exhaustive search for alpha, beta, and signs |
| 660 | + to ensure Real :math:`SO(4)` correction gates. |
| 661 | + Returns the best found parameters along with the basis O_v and |
| 662 | + v_mag for the kernel V. |
| 663 | + """ |
| 664 | + CNOT10_np = np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex) |
| 665 | + CNOT10 = math.cast_like(CNOT10_np, U) |
| 666 | + best_result = None |
| 667 | + min_error = np.inf |
| 668 | + |
| 669 | + # Search Parameter Permutations |
| 670 | + for alpha, beta in candidates: |
| 671 | + # Construct Kernel V |
| 672 | + rza = ops.RZ.compute_matrix(alpha) |
| 673 | + rxb = ops.RX.compute_matrix(beta) |
| 674 | + kernel_inner = math.kron(rza, rxb) |
| 675 | + V = _multidot(CNOT10, kernel_inner, CNOT10) |
| 676 | + |
| 677 | + # Get basis for V |
| 678 | + v_mag = _multidot(math.cast_like(E_dag, U), V, math.cast_like(E, U)) |
| 679 | + gamma_v = math.dot(v_mag, math.T(v_mag)) |
| 680 | + _, O_v = _get_basis_and_eigenvalues(gamma_v) |
| 681 | + |
| 682 | + # P = v^dag O_v, Q = O_u^T u |
| 683 | + P = math.dot(math.conj(math.T(v_mag)), O_v) |
| 684 | + Q = math.dot(math.T(O_u), u_mag) |
| 685 | + |
| 686 | + # Sign Search to fix Gauge Freedom |
| 687 | + for signs in product([1.0, -1.0], repeat=4): |
| 688 | + # Enforce determinant +1 to stay in SO(4) |
| 689 | + if np.prod(signs) < 0: |
| 690 | + continue |
| 691 | + |
| 692 | + S_diag = math.cast_like(signs, P) |
| 693 | + # Broadcasting P * S is faster than matrix mult |
| 694 | + # Metric: sum of absolute imaginary parts (should be 0 for valid decomposition) |
| 695 | + R_trial = math.dot(P * S_diag, Q) |
| 696 | + error = math.sum(math.abs(math.imag(R_trial))) |
| 697 | + |
| 698 | + if error < min_error: |
| 699 | + min_error = error |
| 700 | + best_result = (alpha, beta, signs, O_v, v_mag) |
| 701 | + |
| 702 | + # FALLBACK CHECK: |
| 703 | + # If the best error we found is still "large" (e.g., > 1e-5), |
| 704 | + # then we failed to find a valid Real-valued decomposition. |
| 705 | + # We return None to signal that 2-CNOT decomposition is likely impossible/unsafe. |
| 706 | + if min_error > 1e-5: |
| 707 | + return None |
| 708 | + |
| 709 | + return best_result |
| 710 | + |
| 711 | + |
| 712 | +def _decompose_2_cnots(U, wires, initial_phase): |
| 713 | + r"""Decompose a two-qubit unitary known to require exactly 2 CNOTs. |
| 714 | +
|
| 715 | + The resulting circuit has the following canonical form: |
| 716 | + 0: ──A──╭X──RZ(a)──╭X──C──┤ |
| 717 | + 1: ──B──╰●──RX(b)──╰●──D──┤ |
| 718 | +
|
| 719 | + where A, B, C, D are single-qubit unitaries (:math:`SU(2)` gates) and a, b |
| 720 | + are rotation angles determined by the entanglement properties of the unitary. |
| 721 | +
|
| 722 | + This implementation is based on the work by Shende, Bullock, and Markov and |
| 723 | + this is done following the methods in https://arxiv.org/abs/quant-ph/0308045. |
| 724 | +
|
| 725 | + The decomposition relies on the Magic Basis, where local unitaries correspond to |
| 726 | + orthogonal matrices (:math:`SO(4)`), and the entangling power of an operator is captured by |
| 727 | + the spectrum (eigenvalues) of the symmetric invariant matrix: :math:`\gamma(U) = (E^\dagger U E) (E^\dagger U E)^T`. |
| 728 | +
|
| 729 | + The algorithm proceeds in four main steps: |
| 730 | +
|
| 731 | + 1. The invariant matrix :math:`\gamma(U)` is computed. Its eigenvalues |
| 732 | + are extracted. These eigenvalues come in conjugate pairs, |
| 733 | + and their phases directly determine the rotation parameters a, and b |
| 734 | + needed for the circuit's core. |
| 735 | +
|
| 736 | + 2. A reference operator V is built using the |
| 737 | + calculated parameters a, b. This operator represents the ideal |
| 738 | + 2-CNOT circuit core: :math:`CNOT \cdot (RZ(\alpha) \otimes RX(\beta)) \cdot CNOT`. |
| 739 | + By construction, V is isospectral to U in the Magic Basis. |
| 740 | +
|
| 741 | + 3. To transform the input U into V using only |
| 742 | + local gates, we align their eigenbases. This involves diagonalizing :math:`\gamma(U)` and |
| 743 | + :math:`\gamma(V)` and sorting their eigenvectors canonically to match corresponding subspaces. |
| 744 | +
|
| 745 | + 4. Since eigenvectors are defined only up to a sign (parity), there are |
| 746 | + :math:`2^4 = 16` possible alignments. The algorithm exhaustively searches these combinations |
| 747 | + to find the specific parity that results in a valid, real-valued local transformation |
| 748 | + (a matrix in :math:`SO(4)`). This transformation determines the local gates A, B, C, D. |
| 749 | +
|
| 750 | + The final circuit is then constructed by applying these local gates around the V. |
| 751 | + """ |
| 752 | + # pylint: disable=too-many-locals |
| 753 | + # 1. Compute gamma(U) |
| 754 | + u_mag = _multidot(math.cast_like(E_dag, U), U, math.cast_like(E, U)) |
| 755 | + gamma_u = math.dot(u_mag, math.T(u_mag)) |
| 756 | + |
| 757 | + # 2. Extract interaction parameters |
| 758 | + eig_u, O_u = _get_basis_and_eigenvalues(gamma_u) |
| 759 | + # Extract phases and sort to group conjugate pairs |
| 760 | + abs_angles = np.sort(np.abs(math.angle(eig_u))) |
| 761 | + # Pick distinct representatives |
| 762 | + theta1 = abs_angles[3] |
| 763 | + theta2 = abs_angles[1] |
| 764 | + # Map to circuit parameters |
| 765 | + a_calc = (theta1 + theta2) / 2 |
| 766 | + b_calc = (theta1 - theta2) / 2 |
| 767 | + candidates = [(a_calc, b_calc), (b_calc, a_calc)] |
| 768 | + |
| 769 | + # 3. Perform Search (Delegated to helper) |
| 770 | + result = _find_so4_decomposition(U, u_mag, O_u, candidates) |
| 771 | + |
| 772 | + # SAFETY FALLBACK to _decompose_3_cnots: |
| 773 | + # If the 2-CNOT search failed (result is None), it means U is likely |
| 774 | + # not a 2-CNOT gate (despite trace invariants) or numerical noise is too high. |
| 775 | + # We fall back to the generic 3-CNOT decomposition to guarantee correctness. |
| 776 | + if result is None: |
| 777 | + return _decompose_3_cnots(U, wires, initial_phase) |
| 778 | + |
| 779 | + alpha_f, beta_f, signs_f, O_v, v_mag = result |
| 780 | + |
| 781 | + # 4. Compute Local Gates L (Left) and R (Right) in SO(4) |
| 782 | + S_mat = math.diag(signs_f) |
| 783 | + L = _multidot(O_u, S_mat, math.T(O_v)) |
| 784 | + R = _multidot(math.conj(math.T(v_mag)), math.T(L), u_mag) |
| 785 | + |
| 786 | + # 5. Convert to local gates |
| 787 | + AB = _multidot(math.cast_like(E, U), L, math.cast_like(E_dag, U)) |
| 788 | + CD = _multidot(math.cast_like(E, U), R, math.cast_like(E_dag, U)) |
| 789 | + |
| 790 | + A, B = math.decomposition.su2su2_to_tensor_products(AB) |
| 791 | + C, D = math.decomposition.su2su2_to_tensor_products(CD) |
| 792 | + |
| 793 | + # 6. Queue Circuit |
| 794 | + ops.QubitUnitary(C, wires=wires[0]) |
| 795 | + ops.QubitUnitary(D, wires=wires[1]) |
| 796 | + |
| 797 | + ops.CNOT(wires=[wires[1], wires[0]]) |
| 798 | + ops.RZ(alpha_f, wires=wires[0]) |
| 799 | + ops.RX(beta_f, wires=wires[1]) |
| 800 | + ops.CNOT(wires=[wires[1], wires[0]]) |
| 801 | + |
| 802 | + ops.QubitUnitary(A, wires=wires[0]) |
| 803 | + ops.QubitUnitary(B, wires=wires[1]) |
| 804 | + |
| 805 | + return math.cast_like(0.0, initial_phase) |
| 806 | + |
| 807 | + |
607 | 808 | def _multidot(*matrices): |
608 | 809 | mat = matrices[0] |
609 | 810 | for m in matrices[1:]: |
|
0 commit comments