Skip to content
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
that produces a set of gate names to be used as the target gate set in decompositions.
[(#8522)](https://github.com/PennyLaneAI/pennylane/pull/8522)

* The :func:`~pennylane.transforms.decompose` transform now accepts a `minimize_work_wires` argument. With
the new graph-based decomposition system activated via :func:`~pennylane.decomposition.enable_graph`,
and `minimize_work_wires` set to `True`, the decomposition system will select decomposition rules that
minimizes the maximum number of simultaneously allocated work wires.
[(#8729)](https://github.com/PennyLaneAI/pennylane/pull/8729)
[(#8734)](https://github.com/PennyLaneAI/pennylane/pull/8734)

<h4>Pauli product measurements</h4>

* Added a :func:`~pennylane.ops.pauli_measure` that takes a Pauli product measurement.
Expand Down
64 changes: 52 additions & 12 deletions pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class _OperatorNode:

"""

min_work_wires: int = 0
"""The minimum number of additional work wires required to decompose this operator."""

def __hash__(self) -> int:
# If the decomposition of an operator does not depend on the availability of work wires
# at all, we don't need to have multiple nodes representing the same operator with
Expand Down Expand Up @@ -119,6 +122,10 @@ class _DecompositionNode:
work_wire_spec: WorkWireSpec
num_work_wire_not_available: int
work_wire_dependent: bool = False
min_work_wires: int = 0

def __post_init__(self):
self.min_work_wires = self.min_work_wires or self.work_wire_spec.total

def count(self, op: CompressedResourceOp):
"""Find the number of occurrences of an operator in the decomposition."""
Expand Down Expand Up @@ -253,6 +260,7 @@ def __init__(
self._graph = rx.PyDiGraph()

# Construct the decomposition graph
self._min_work_wires = 0
self._start = self._graph.add_node(None)
self._construct_graph(operations)

Expand All @@ -265,6 +273,7 @@ def _construct_graph(self, operations: Iterable[Operator | CompressedResourceOp]
op = resource_rep(type(op), **op.resource_params)
idx = self._add_op_node(op, 0)
self._original_ops_indices.add(idx)
self._min_work_wires = max(self._min_work_wires, self._graph[idx].min_work_wires)

def _add_op_node(self, op: CompressedResourceOp, num_used_work_wires: int) -> int:
"""Recursively adds an operation node to the graph.
Expand Down Expand Up @@ -299,30 +308,37 @@ def _add_op_node(self, op: CompressedResourceOp, num_used_work_wires: int) -> in
self._graph.add_edge(self._start, op_node_idx, self._gate_set_weights[op.name])
return op_node_idx

update_op_to_work_wire_dependent = False
work_wire_dependent = known_work_wire_dependent
min_work_wires = -1 # use -1 to represent undetermined work wire requirement
for decomposition in self._get_decompositions(op):
d_node = self._add_decomp(decomposition, op_node, op_node_idx, num_used_work_wires)
# If any of the operator's decompositions depend on work wires, this operator
# should also depend on work wires.
if d_node and d_node.work_wire_dependent and not known_work_wire_dependent:
update_op_to_work_wire_dependent = True
if d_node and d_node.work_wire_dependent:
work_wire_dependent = True
if d_node and (min_work_wires == -1 or d_node.min_work_wires < min_work_wires):
min_work_wires = d_node.min_work_wires

# If we found that this operator depends on work wires, but it's currently recorded
# as independent of work wires, we must replace every record of this operator node
# with a new node with `work_wire_dependent` set to `True`.
if update_op_to_work_wire_dependent:
new_op_node = replace(op_node, work_wire_dependent=True)
self._all_op_indices[new_op_node] = self._all_op_indices.pop(op_node)
self._graph[op_node_idx] = new_op_node
self._op_to_op_nodes[op].remove(op_node)
self._op_to_op_nodes[op].add(new_op_node)
if not known_work_wire_dependent and work_wire_dependent:
new_op_node = replace(op_node, work_wire_dependent=True, min_work_wires=min_work_wires)
self._replace_node(op_node_idx, new_op_node)
# Also record that this operator type depends on work wires, so in the future
# when we encounter other instances of the same operator type, we correctly
# identify it as work-wire dependent.
self._work_wire_dependent_ops.add(op_node.op)

return op_node_idx

def _replace_node(self, idx: int, new_node: _OperatorNode) -> None:
original_node = self._graph[idx]
self._all_op_indices[new_node] = self._all_op_indices.pop(original_node)
self._graph[idx] = new_node
self._op_to_op_nodes[new_node.op].remove(original_node)
self._op_to_op_nodes[new_node.op].add(new_node)

def _add_decomp(
self,
rule: DecompositionRule,
Expand All @@ -349,14 +365,21 @@ def _add_decomp(
if work_wire_spec.total:
d_node.work_wire_dependent = True

# For a decomposition rule, the minimum required number of work wires of this decomposition
# rule is determined by operator that uses the MOST number of work wires.
max_op_min_work_wires = 0
for op in decomp_resource.gate_counts:
op_node_idx = self._add_op_node(op, num_used_work_wires + work_wire_spec.total)
self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx))
# If any of the operators in the decomposition depends on work wires, this
# decomposition is also dependent on work wires, even it itself does not use
# any work wires.
if self._graph[op_node_idx].work_wire_dependent:
op_node = self._graph[op_node_idx]
if op_node.work_wire_dependent:
d_node.work_wire_dependent = True
max_op_min_work_wires = max(op_node.min_work_wires, max_op_min_work_wires)

d_node.min_work_wires += max_op_min_work_wires

self._graph.add_edge(d_node_idx, op_idx, 0)
return d_node
Expand Down Expand Up @@ -466,7 +489,9 @@ def _get_controlled_decompositions(self, op: CompressedResourceOp) -> list[Decom

return rules

def solve(self, num_work_wires: int | None = 0, lazy=True) -> DecompGraphSolution:
def solve(
self, num_work_wires: int | None = 0, lazy=True, minimize_work_wires=False
) -> DecompGraphSolution:
"""Solves the graph using the Dijkstra search algorithm.

Args:
Expand All @@ -475,11 +500,22 @@ def solve(self, num_work_wires: int | None = 0, lazy=True) -> DecompGraphSolutio
lazy (bool): If True, the Dijkstra search will stop once optimal decompositions are
found for all operations that the graph was initialized with. Otherwise, the
entire graph will be explored.
minimize_work_wires (bool): If True, minimize the number of additional work wires used.

Returns:
DecompGraphSolution

"""

if num_work_wires is not None and num_work_wires < self._min_work_wires:
raise DecompositionError(
f"The circuit requires at least {self._min_work_wires} work wires to decompose, "
f"the graph cannot be solved with {num_work_wires} available work wires."
)

if minimize_work_wires:
num_work_wires = self._min_work_wires

visitor = DecompositionSearchVisitor(
self._graph,
self._gate_set_weights,
Expand All @@ -504,7 +540,9 @@ def solve(self, num_work_wires: int | None = 0, lazy=True) -> DecompGraphSolutio
f"{op_names} to the target gate set {set(self._gate_set_weights)}.",
UserWarning,
)
return DecompGraphSolution(visitor, self._all_op_indices, self._op_to_op_nodes)
return DecompGraphSolution(
visitor, self._all_op_indices, self._op_to_op_nodes, num_work_wires
)


class DecompGraphSolution:
Expand Down Expand Up @@ -544,11 +582,13 @@ def __init__(
visitor: DecompositionSearchVisitor,
all_op_indices: dict[_OperatorNode, int],
op_to_op_nodes: dict[CompressedResourceOp, set[_OperatorNode]],
num_work_wires: int | None,
) -> None:
self._visitor = visitor
self._graph = visitor._graph
self._op_to_op_nodes = op_to_op_nodes
self._all_op_indices = all_op_indices
self.num_work_wires = num_work_wires

def _all_solutions(
self, visitor: DecompositionSearchVisitor, op: Operator, num_work_wires: int | None
Expand Down
22 changes: 18 additions & 4 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
stopping_condition=None,
max_expansion=None,
max_work_wires=0,
minimize_work_wires=False,
fixed_decomps=None,
alt_decomps=None,
): # pylint: disable=too-many-arguments
Expand All @@ -106,6 +107,7 @@ def __init__(
self._target_gate_names = None
self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps
self._max_work_wires = max_work_wires
self._minimize_work_wires = minimize_work_wires

# We use a ChainMap to store the environment frames, which allows us to push and pop
# environments without copying the interpreter instance when we evaluate a jaxpr of
Expand Down Expand Up @@ -234,9 +236,11 @@ def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list:
operations,
self._gate_set,
self._max_work_wires,
self._minimize_work_wires,
self._fixed_decomps,
self._alt_decomps,
)
self._max_work_wires = self._decomp_graph_solution.num_work_wires

for eq in jaxpr.eqns:
prim_type = getattr(eq.primitive, "prim_type", "")
Expand Down Expand Up @@ -354,6 +358,7 @@ def decompose(
stopping_condition=None,
max_expansion=None,
max_work_wires: int | None = 0,
minimize_work_wires: bool = False,
fixed_decomps: dict | None = None,
alt_decomps: dict | None = None,
): # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -388,6 +393,8 @@ def decompose(
If ``None``, the circuit will be decomposed until the target gate set is reached.
max_work_wires (int): The maximum number of work wires that can be simultaneously
allocated. If ``None``, assume an infinite number of work wires. Defaults to ``0``.
minimize_work_wires (bool): If ``True``, minimize the number of work wires simultaneously
allocated throughout the circuit. Defaults to ``False``.
fixed_decomps (Dict[Type[Operator], DecompositionRule]): a dictionary mapping operator types
to custom decomposition rules. A decomposition rule is a quantum function decorated with
:func:`~pennylane.register_resources`. The custom decomposition rules specified here
Expand Down Expand Up @@ -748,9 +755,11 @@ def circuit():
tape.operations,
gate_set,
num_work_wires=max_work_wires,
minimize_work_wires=minimize_work_wires,
fixed_decomps=fixed_decomps,
alt_decomps=alt_decomps,
)
max_work_wires = decomp_graph_solution.num_work_wires

try:
new_ops = [
Expand Down Expand Up @@ -1028,18 +1037,23 @@ def _stopping_condition(op):
return gate_set, _stopping_condition


def _construct_and_solve_decomp_graph(
operations, target_gates, num_work_wires, fixed_decomps, alt_decomps
def _construct_and_solve_decomp_graph( # pylint: disable=too-many-arguments
operations,
target_gates,
num_work_wires,
minimize_work_wires,
fixed_decomps,
alt_decomps,
):
"""Create and solve a DecompositionGraph instance to optimize the decomposition."""

# Create the decomposition graph
decomp_graph = DecompositionGraph(
graph = DecompositionGraph(
operations,
target_gates,
fixed_decomps=fixed_decomps,
alt_decomps=alt_decomps,
)

# Find the efficient pathways to the target gate set
return decomp_graph.solve(num_work_wires=num_work_wires)
return graph.solve(num_work_wires=num_work_wires, minimize_work_wires=minimize_work_wires)
52 changes: 52 additions & 0 deletions tests/capture/transforms/test_capture_graph_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,55 @@ def f():
qml.adjoint(qml.X(0)),
qml.adjoint(qml.Y(0)),
]

def test_minimize_work_wires(self):
"""Tests that the number of allocations can be minimized."""

class SomeOtherOp(Operation):
"""Some other operation."""

@qml.register_resources(
{qml.CNOT: 2, LargeOpDynamicWireDecomp: 2},
work_wires={"zeroed": 1},
)
def _some_decomp(wires):
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([wires[0], work_wires[0]])
LargeOpDynamicWireDecomp(wires)
qml.CNOT([wires[0], work_wires[0]])

@DecomposeInterpreter(
gate_set={qml.Toffoli: 1, qml.CRot: 7, qml.CNOT: 1},
max_work_wires=None,
minimize_work_wires=True,
alt_decomps={
CustomOpDynamicWireDecomp: [_decomp_with_work_wire, _decomp_without_work_wire],
LargeOpDynamicWireDecomp: [_decomp2_with_work_wire],
SomeOtherOp: [_some_decomp],
},
)
def circuit():
SomeOtherOp(wires=[0, 1, 2, 3, 4])
CustomOpDynamicWireDecomp(wires=[0, 1, 4])

plxpr = jax.make_jaxpr(circuit)()
decomp = qml.tape.plxpr_to_tape(plxpr.jaxpr, plxpr.consts)
[result], _ = qml.transforms.resolve_dynamic_wires([decomp], min_int=5)

with qml.capture.pause():
with qml.queuing.AnnotatedQueue() as q:
with qml.allocation.allocate(1, state="zero", restored=True) as work_wires:
qml.CNOT([0, work_wires[0]])
with qml.allocation.allocate(1, state="zero", restored=True) as sub_work_wires:
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
_decomp_without_work_wire(wires=[sub_work_wires[0], 2, 3])
qml.Toffoli(wires=[0, 1, sub_work_wires[0]])
_decomp_without_work_wire(wires=[1, 2, 3])
qml.CNOT([0, work_wires[0]])
_decomp_with_work_wire(wires=[0, 1, 4])

expected = qml.tape.QuantumScript.from_queue(q)
[expected], _ = qml.transforms.resolve_dynamic_wires([expected], min_int=5)

for actual, exp in zip(result.operations, expected.operations, strict=True):
qml.assert_equal(actual, exp)
55 changes: 52 additions & 3 deletions tests/decomposition/test_decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,65 @@ class CustomOp(Operation): # pylint: disable=too-few-public-methods
def _custom_decomp(_):
raise NotImplementedError

@qml.register_resources({qml.X: 1})
def _another_decomp(_):
raise NotImplementedError

graph = DecompositionGraph(
[CustomOp(0), SimpleOp(0)],
gate_set={qml.X},
fixed_decomps={SimpleOp: _simple_decomp, CustomOp: _custom_decomp},
alt_decomps={SimpleOp: [_simple_decomp], CustomOp: [_custom_decomp, _another_decomp]},
)
solution = graph.solve()

assert not solution.is_solved_for(CustomOp(0))
assert solution.is_solved_for(SimpleOp(0))

def test_min_work_wires(self, _):
"""Tests that the graph tracks the minimum number of work wires."""

class SimpleOp(Operation): # pylint: disable=too-few-public-methods
"""A simple operation that does not depend on work wires."""

@qml.register_resources({qml.X: 4})
def _simple_decomp(_):
raise NotImplementedError

class CustomOp(Operation): # pylint: disable=too-few-public-methods
"""Another operation."""

@qml.register_resources({SimpleOp: 1, qml.CNOT: 4}, work_wires={"zeroed": 2})
def _custom_decomp(_):
raise NotImplementedError

class AnotherOp(Operation): # pylint: disable=too-few-public-methods
"""Some other op."""

@qml.register_resources({CustomOp: 1, qml.CNOT: 4}, work_wires={"zeroed": 2})
def _another_decomp(_):
raise NotImplementedError

@qml.register_resources({SimpleOp: 3, qml.CNOT: 4}, work_wires={"zeroed": 3})
def _yet_another_decomp(_):
raise NotImplementedError

graph = DecompositionGraph(
[AnotherOp(0)],
gate_set={qml.X, qml.CNOT},
alt_decomps={
SimpleOp: [_simple_decomp],
CustomOp: [_custom_decomp],
AnotherOp: [_another_decomp, _yet_another_decomp],
},
)
assert graph._min_work_wires == 3
with pytest.raises(DecompositionError, match="at least 3 work wires"):
graph.solve(num_work_wires=2)

solution = graph.solve(num_work_wires=None)
assert solution.decomposition(AnotherOp(0)) == _another_decomp

solution = graph.solve(num_work_wires=None, minimize_work_wires=True)
assert solution.decomposition(AnotherOp(0), num_work_wires=None) == _yet_another_decomp


@pytest.mark.unit
@patch(
Expand Down
Loading
Loading