Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,10 @@ A warning message has been added to :doc:`Building a plugin <../development/plug
* Fixes a bug where :func:`~.change_op_basis` cannot be captured when the `uncompute_op` is left out.
[(#8695)](https://github.com/PennyLaneAI/pennylane/pull/8695)

* Fixes a bug in :func:`~qml.ops.rs_decomposition` where correct solution candidates were being rejected
due to some incorrect GCD computations.
[(#8625)](https://github.com/PennyLaneAI/pennylane/pull/8625)

* Fixes a bug where decomposition rules are sometimes incorrectly disregarded by the `DecompositionGraph` when a higher level
decomposition rule uses dynamically allocated work wires.
[(#8725)](https://github.com/PennyLaneAI/pennylane/pull/8725)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/op_math/decompositions/norm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _(elem1, elem2):
@_gcd.register(ZOmega)
def _(elem1, elem2):
while elem2 != 0:
elem1, elem2 = elem2, elem2 % elem1
elem1, elem2 = elem2, elem1 % elem2
return elem1


Expand Down
19 changes: 16 additions & 3 deletions pennylane/ops/op_math/decompositions/rings.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,21 @@ def __mod__(self, other: ZSqrtTwo | int | float) -> ZSqrtTwo:
if isinstance(other, int) or (isinstance(other, float) and other.is_integer()):
return ZSqrtTwo(self.a % int(other), self.b % int(other))

if self in (zero := ZSqrtTwo(0, 0), other): # Trivial cases
return zero

d = abs(other)
n1, n2 = (self.a * other.a - 2 * self.b * other.b), (self.b * other.a - self.a * other.b)
return self - ZSqrtTwo(round(n1 / d), round(n2 / d)) * other
if (dv := ZSqrtTwo(n1 // d, n2 // d)) != ZSqrtTwo(0, 0): # Check if floor division works
return self - dv * other

# If floor division leads to a zero divisor, search neighbours.
dv_a, dv_b = max(round(n1 / d), dv.a), dv.b
if dv_a == dv.a:
dv_b = max(round(n2 / d), dv.b)

# Adjust the sign difference based on the adjusted values.
return (-1) ** (dv_a != dv.a or dv_b != dv.b) * (self - ZSqrtTwo(dv_a, dv_b) * other)

@property
def flatten(self: ZSqrtTwo) -> list[int]:
Expand Down Expand Up @@ -288,8 +300,9 @@ def __floordiv__(self, other: int) -> ZOmega:

def __mod__(self, other: ZOmega) -> ZOmega:
d = abs(other)
n = self * other.conj() * ((other * other.conj()).adj2())
return ZOmega(*[(s + d // 2) // d for s in n.flatten]) * other - self
n = self * other.conj() * (other * other.conj()).adj2()
r = other * ZOmega(*[(s + d // 2) // d for s in n.flatten])
return self - r if abs(self) > abs(r) else r - self

@classmethod
def from_sqrt_pair(cls, alpha: ZSqrtTwo, beta: ZSqrtTwo, shift: ZOmega) -> ZOmega:
Expand Down
72 changes: 66 additions & 6 deletions tests/ops/op_math/decompositions/test_norm_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pennylane.ops.op_math.decompositions.norm_solver import (
_factorize_prime_zomega,
_factorize_prime_zsqrt_two,
_gcd,
_integer_factorize,
_primality_test,
_prime_factorize,
Expand All @@ -30,6 +31,45 @@
from pennylane.ops.op_math.decompositions.rings import ZOmega, ZSqrtTwo


class TestGCD:
"""Tests for the GCD function."""

@pytest.mark.parametrize(
"a, b, expected",
[
(ZSqrtTwo(28), ZSqrtTwo(12), ZSqrtTwo(4)),
(ZSqrtTwo(15), ZSqrtTwo(25), ZSqrtTwo(5)),
(ZSqrtTwo(81), ZSqrtTwo(63), ZSqrtTwo(9)),
(ZSqrtTwo(144), ZSqrtTwo(108), ZSqrtTwo(36)),
(ZSqrtTwo(23, 72), ZSqrtTwo(23, 72), ZSqrtTwo(23, 72)),
(ZSqrtTwo(28, 15), ZSqrtTwo(12, 25), ZSqrtTwo(58, 41)),
(ZSqrtTwo(35, 42), ZSqrtTwo(22, 16), ZSqrtTwo(11, 8)),
(ZSqrtTwo(-1, 7), ZSqrtTwo(7, 0), ZSqrtTwo(1, 0)),
],
)
def test_gcd_zsqrt_two(self, a, b, expected):
"""Test the GCD function."""
assert _gcd(a, b) == expected
res1, res2 = a / expected, b / expected
assert res1 * expected == a
assert res2 * expected == b

@pytest.mark.parametrize(
"a, b, expected",
[
(ZOmega(d=28), ZOmega(d=12), ZOmega(d=4)),
(ZOmega(d=15), ZOmega(d=25), ZOmega(d=5)),
(ZOmega(d=81), ZOmega(d=63), ZOmega(d=9)),
(ZOmega(d=144), ZOmega(d=108), ZOmega(d=36)),
(ZOmega(d=28, b=12), ZOmega(d=32, b=44), ZOmega(d=4)),
(ZOmega(a=81, c=63), ZOmega(a=36, c=42), ZOmega(a=3, c=3)),
],
)
def test_gcd_zomega(self, a, b, expected):
"""Test the GCD function."""
assert _gcd(a, b) == expected


class TestFactorization:
"""Tests for the factorization functions."""

Expand Down Expand Up @@ -84,12 +124,12 @@ def test_factorize_prime_zsqrt_two(self, num, expected):
@pytest.mark.parametrize(
"num, expected",
[
(3, ZOmega(d=3)),
(3, ZOmega(a=1, c=1, d=1)),
(27, None),
(5, ZOmega(b=-2, d=-1)),
(5, ZOmega(b=-1, d=2)),
(7, None),
(11, ZOmega(d=11)),
(13, ZOmega(b=-2, d=-3)),
(11, ZOmega(a=1, c=1, d=3)),
(13, ZOmega(b=2, d=3)),
],
)
def test_factorize_prime_zomega(self, num, expected):
Expand Down Expand Up @@ -147,14 +187,34 @@ def test_sqrt_modulo_p(self, nums, expected):
(ZSqrtTwo(2, -1), ZOmega(a=1, b=-1, c=0, d=0)),
(ZSqrtTwo(7, 0), None),
(ZSqrtTwo(23, 0), None),
(ZSqrtTwo(7, 2), ZOmega(a=1, b=1, c=1, d=2)),
(ZSqrtTwo(7, 2), -ZOmega(a=1, b=1, c=2, d=-1)),
(ZSqrtTwo(17, 0), None),
(ZSqrtTwo(5, 2), ZOmega(a=-2, b=-1, c=0, d=0)),
(ZSqrtTwo(13, 6), ZOmega(a=0, b=2, c=3, d=0)),
(ZSqrtTwo(13, 6), ZOmega(a=3, b=0, c=0, d=-2)),
],
)
def test_solve_diophantine(self, num, expected):
"""Test `solve_diophantine` solves diophantine equation."""
assert _solve_diophantine(num) == expected
if expected is not None:
assert (expected.conj() * expected).to_sqrt_two() == num

@pytest.mark.parametrize(
"num, expected, factor",
[
(
ZOmega(-26687414, 10541729, 10614512, 40727366),
ZOmega(-30805761, 23432014, -2332111, -20133911),
52,
),
(
ZOmega(-22067493351, 22078644868, 52098814989, 16270802723),
ZOmega(-4737137864, -21764478939, 70433513740, -5852668010),
73,
),
],
)
def test_solve_diophantine_large_number(self, num, expected, factor):
"""Test `solve_diophantine` solves diophantine equation."""
xi = ZSqrtTwo(2**factor) - num.norm().to_sqrt_two()
assert _solve_diophantine(xi) == expected
27 changes: 27 additions & 0 deletions tests/ops/op_math/decompositions/test_rings.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ def test_arithmetic_operations(self):
assert z2.to_omega() == ZOmega(a=-4, b=0, c=4, d=3)
assert 1 - ZSqrtTwo(1, 2) == ZSqrtTwo(0, -2)

@pytest.mark.parametrize(
"z1, z2, expected",
[
(ZSqrtTwo(-1, 7), ZSqrtTwo(0, 1), ZSqrtTwo(1, 0)),
(ZSqrtTwo(-1, 7), ZSqrtTwo(0, 7), ZSqrtTwo(13, 0)),
(ZSqrtTwo(0, 0), ZSqrtTwo(2, 3), ZSqrtTwo(0, 0)),
(ZSqrtTwo(3, -9), ZSqrtTwo(1, -3), ZSqrtTwo(0, 0)),
],
)
def test_arithmetic_modulo(self, z1, z2, expected):
"""Test arithmetic modulo operation on ZSqrtTwo."""
assert z1 % z2 == expected

def test_arithmetic_errors(self):
"""Test that arithmetic operations raise errors for invalid types."""
z1 = ZSqrtTwo(1, 2)
Expand Down Expand Up @@ -141,6 +154,20 @@ def test_arithmetic_operations(self):
assert (z1 - ZOmega(a=2, b=2, c=2)).to_sqrt_two() == ZSqrtTwo(a=4, b=1)
assert 1 - ZOmega() == ZOmega(d=1)

@pytest.mark.parametrize(
"z1, z2, expected",
[
(ZOmega(24, 12, 32, 4), ZOmega(6, 3, 8, 1), ZOmega(0, 0, 0, 0)),
(ZOmega(3, -1, 4, 2), ZOmega(1, 2, 0, 3), ZOmega(0, 0, 0, 2)),
(ZOmega(0, 0, 0, 0), ZOmega(2, 3, 7, 5), ZOmega(0, 0, 0, 0)),
(ZOmega(12, -7, 4, 9), ZOmega(0, 0, 5, -3), ZOmega(1, 1, 1, -1)),
(ZOmega(-17, 22, -9, 5), ZOmega(3, -4, 2, -1), ZOmega(-1, 2, -3, 1)),
],
)
def test_arithmetic_modulo(self, z1, z2, expected):
"""Test arithmetic modulo operation on ZOmega."""
assert z1 % z2 == expected

def test_arithmetic_errors(self):
"""Test that arithmetic operations raise errors for invalid types."""
z1 = ZOmega(1, 2, 3, 4)
Expand Down