Skip to content
9 changes: 5 additions & 4 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
that produces a set of gate names to be used as the target gate set in decompositions.
[(#8522)](https://github.com/PennyLaneAI/pennylane/pull/8522)

* The :class:`~pennylane.decomposition.DecompositionGraph` now tracks the minimum number of
dynamic wire allocations required to solve the circuit, and provides a `minimize_work_wires`
option that enables the graph to select the best decomposition rules while minimizing the
number of additional allocations of work wires.
* The :func:`~pennylane.transforms.decompose` transform now accepts a `minimize_work_wires` argument. With
the new graph-based decomposition system activated via :func:`~pennylane.decomposition.enable_graph`,
and `minimize_work_wires` set to `True`, the decomposition system will select decomposition rules that
minimizes the maximum number of simultaneously allocated work wires.
[(#8729)](https://github.com/PennyLaneAI/pennylane/pull/8729)
[(#8734)](https://github.com/PennyLaneAI/pennylane/pull/8734)

<h4>Pauli product measurements</h4>

Expand Down
6 changes: 5 additions & 1 deletion pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,9 @@ def solve(
f"{op_names} to the target gate set {set(self._gate_set_weights)}.",
UserWarning,
)
return DecompGraphSolution(visitor, self._all_op_indices, self._op_to_op_nodes)
return DecompGraphSolution(
visitor, self._all_op_indices, self._op_to_op_nodes, num_work_wires
)


class DecompGraphSolution:
Expand Down Expand Up @@ -580,11 +582,13 @@ def __init__(
visitor: DecompositionSearchVisitor,
all_op_indices: dict[_OperatorNode, int],
op_to_op_nodes: dict[CompressedResourceOp, set[_OperatorNode]],
num_work_wires: int | None,
) -> None:
self._visitor = visitor
self._graph = visitor._graph
self._op_to_op_nodes = op_to_op_nodes
self._all_op_indices = all_op_indices
self.num_work_wires = num_work_wires

def _all_solutions(
self, visitor: DecompositionSearchVisitor, op: Operator, num_work_wires: int | None
Expand Down
1 change: 1 addition & 0 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def decompose( # pylint: disable = too-many-positional-arguments
operations=decomposable_ops,
target_gates=target_gates,
num_work_wires=num_available_work_wires,
minimize_work_wires=False,
fixed_decomps=None,
alt_decomps=None,
)
Expand Down
22 changes: 18 additions & 4 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
stopping_condition=None,
max_expansion=None,
max_work_wires=0,
minimize_work_wires=False,
fixed_decomps=None,
alt_decomps=None,
): # pylint: disable=too-many-arguments
Expand All @@ -106,6 +107,7 @@ def __init__(
self._target_gate_names = None
self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps
self._max_work_wires = max_work_wires
self._minimize_work_wires = minimize_work_wires

# We use a ChainMap to store the environment frames, which allows us to push and pop
# environments without copying the interpreter instance when we evaluate a jaxpr of
Expand Down Expand Up @@ -234,9 +236,11 @@ def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list:
operations,
self._gate_set,
self._max_work_wires,
self._minimize_work_wires,
self._fixed_decomps,
self._alt_decomps,
)
self._max_work_wires = self._decomp_graph_solution.num_work_wires

for eq in jaxpr.eqns:
prim_type = getattr(eq.primitive, "prim_type", "")
Expand Down Expand Up @@ -354,6 +358,7 @@ def decompose(
stopping_condition=None,
max_expansion=None,
max_work_wires: int | None = 0,
minimize_work_wires: bool = False,
fixed_decomps: dict | None = None,
alt_decomps: dict | None = None,
): # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -388,6 +393,8 @@ def decompose(
If ``None``, the circuit will be decomposed until the target gate set is reached.
max_work_wires (int): The maximum number of work wires that can be simultaneously
allocated. If ``None``, assume an infinite number of work wires. Defaults to ``0``.
minimize_work_wires (bool): If ``True``, minimize the number of work wires simultaneously
allocated throughout the circuit. Defaults to ``False``.
fixed_decomps (Dict[Type[Operator], DecompositionRule]): a dictionary mapping operator types
to custom decomposition rules. A decomposition rule is a quantum function decorated with
:func:`~pennylane.register_resources`. The custom decomposition rules specified here
Expand Down Expand Up @@ -748,9 +755,11 @@ def circuit():
tape.operations,
gate_set,
num_work_wires=max_work_wires,
minimize_work_wires=minimize_work_wires,
fixed_decomps=fixed_decomps,
alt_decomps=alt_decomps,
)
max_work_wires = decomp_graph_solution.num_work_wires

try:
new_ops = [
Expand Down Expand Up @@ -1028,18 +1037,23 @@ def _stopping_condition(op):
return gate_set, _stopping_condition


def _construct_and_solve_decomp_graph(
operations, target_gates, num_work_wires, fixed_decomps, alt_decomps
def _construct_and_solve_decomp_graph( # pylint: disable=too-many-arguments
operations,
target_gates,
num_work_wires,
minimize_work_wires,
fixed_decomps,
alt_decomps,
):
"""Create and solve a DecompositionGraph instance to optimize the decomposition."""

# Create the decomposition graph
decomp_graph = DecompositionGraph(
graph = DecompositionGraph(
operations,
target_gates,
fixed_decomps=fixed_decomps,
alt_decomps=alt_decomps,
)

# Find the efficient pathways to the target gate set
return decomp_graph.solve(num_work_wires=num_work_wires)
return graph.solve(num_work_wires=num_work_wires, minimize_work_wires=minimize_work_wires)
60 changes: 60 additions & 0 deletions tests/capture/transforms/test_capture_graph_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,63 @@ def f():
qml.adjoint(qml.X(0)),
qml.adjoint(qml.Y(0)),
]

def test_minimize_work_wires(self):
"""Tests that the number of allocations can be minimized."""

class SomeOtherOp(Operation):
"""Some other operation."""

@qml.register_resources(
{qml.CNOT: 2, LargeOpDynamicWireDecomp: 2},
work_wires={"zeroed": 1},
)
def _some_decomp(wires):
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([wires[0], work_wires[0]])
LargeOpDynamicWireDecomp(wires)
qml.CNOT([wires[0], work_wires[0]])

@DecomposeInterpreter(
gate_set={qml.Toffoli: 1, qml.CRot: 7, qml.CNOT: 1},
max_work_wires=None,
minimize_work_wires=True,
alt_decomps={
CustomOpDynamicWireDecomp: [_decomp_with_work_wire, _decomp_without_work_wire],
LargeOpDynamicWireDecomp: [_decomp2_with_work_wire],
SomeOtherOp: [_some_decomp],
},
)
def circuit():
SomeOtherOp(wires=[0, 1, 2, 3, 4])
CustomOpDynamicWireDecomp(wires=[0, 1, 4])

plxpr = jax.make_jaxpr(circuit)()
decomp = qml.tape.plxpr_to_tape(plxpr.jaxpr, plxpr.consts)
[result], _ = qml.transforms.resolve_dynamic_wires([decomp], min_int=5)

with qml.capture.pause():
with qml.queuing.AnnotatedQueue() as q:
# The only decomposition rule available for SomeOtherOp
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([0, work_wires[0]])
# The only decomposition available for LargeOpDynamicWireDecomp
with qml.allocation.allocate(1, state="zero", restored=True) as sub_work_wires:
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
# At this point, to minimize the number of work wires allocated, we
# select the decomposition rule that does not use any work wires for
# the CustomOpDynamicWireDecomp at the very bottom of the chain
_decomp_without_work_wire(wires=[sub_work_wires[0], 2, 3])
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
_decomp_without_work_wire(wires=[1, 2, 3])
qml.CNOT([0, work_wires[0]])
# Since the SomeOtherOp that came before already used two work wires, this
# second CustomOpDynamicWireDecomp should be free to use up to two work wires,
# and we verify that this is indeed what happens.
_decomp_with_work_wire(wires=[0, 1, 4])

expected = qml.tape.QuantumScript.from_queue(q)
[expected], _ = qml.transforms.resolve_dynamic_wires([expected], min_int=5)

for actual, exp in zip(result.operations, expected.operations, strict=True):
qml.assert_equal(actual, exp)
59 changes: 59 additions & 0 deletions tests/transforms/test_decompose_transform_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,65 @@ def test_dynamic_work_wire_allocation(self, num_work_wires, expected_gate_count)
gate_counts[type(op)] += 1
assert gate_counts == expected_gate_count

def test_minimize_work_wires(self):
"""Tests that the number of allocations can be minimized."""

class SomeOtherOp(Operation): # pylint: disable=too-few-public-methods
"""Some other operation."""

@qml.register_resources(
{qml.CNOT: 2, LargeOpDynamicWireDecomp: 2},
work_wires={"zeroed": 1},
)
def _some_decomp(wires):
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([wires[0], work_wires[0]])
LargeOpDynamicWireDecomp(wires)
qml.CNOT([wires[0], work_wires[0]])

op1 = SomeOtherOp(wires=[0, 1, 2, 3, 4])
op2 = CustomOpDynamicWireDecomp(wires=[0, 1, 4])
tape = qml.tape.QuantumScript([op1, op2])

[decomp], _ = qml.transforms.decompose(
[tape],
gate_set={qml.Toffoli: 1, qml.CRot: 7, qml.CNOT: 1},
max_work_wires=None,
minimize_work_wires=True,
alt_decomps={
CustomOpDynamicWireDecomp: [_decomp_with_work_wire, _decomp_without_work_wire],
LargeOpDynamicWireDecomp: [_decomp2_with_work_wire],
SomeOtherOp: [_some_decomp],
},
)

[result], _ = qml.transforms.resolve_dynamic_wires([decomp], min_int=5)

with qml.queuing.AnnotatedQueue() as q:
# The only decomposition rule available for SomeOtherOp
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([0, work_wires[0]])
# The only decomposition available for LargeOpDynamicWireDecomp
with qml.allocation.allocate(1, state="zero", restored=True) as sub_work_wires:
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
# At this point, to minimize the number of work wires allocated, we
# select the decomposition rule that does not use any work wires for
# the CustomOpDynamicWireDecomp at the very bottom of the chain
_decomp_without_work_wire(wires=[sub_work_wires[0], 2, 3])
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
_decomp_without_work_wire(wires=[1, 2, 3])
qml.CNOT([0, work_wires[0]])
# Since the SomeOtherOp that came before already used two work wires, this
# second CustomOpDynamicWireDecomp should be free to use up to two work wires,
# and we verify that this is indeed what happens.
_decomp_with_work_wire(wires=[0, 1, 4])

expected = qml.tape.QuantumScript.from_queue(q)
[expected], _ = qml.transforms.resolve_dynamic_wires([expected], min_int=5)

for actual, exp in zip(result.operations, expected.operations, strict=True):
qml.assert_equal(actual, exp)


@pytest.mark.capture
@pytest.mark.system
Expand Down