Skip to content

Commit f0406e3

Browse files
authored
Merge branch 'master' into compact_pauli
2 parents 1ba303c + 02c62c3 commit f0406e3

30 files changed

+1174
-117
lines changed

.github/workflows/documentation-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ jobs:
5555
run: |
5656
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
5757
pip install -e .
58-
# TODO: use 0.7.0 after updating all the documentation
59-
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
58+
pip install sybil pytest "jax==0.7.1" "jaxlib==0.7.1" torch matplotlib pyzx
6059
6160
- name: Print Dependencies
6261
run: |

.github/workflows/interface-dependency-versions.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ on:
1616
description: The version of JAX to use for testing
1717
required: false
1818
type: string
19-
default: '0.7.0'
19+
default: '0.7.1'
2020
catalyst_jax_version:
2121
description: The version of JAX to use for testing along with Catalyst
2222
required: false
2323
type: string
24-
default: '0.7.0'
24+
default: '0.7.1'
2525
torch_version:
2626
description: The version of PyTorch to use for testing
2727
required: false

.github/workflows/interface-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ jobs:
109109
- name: Set pytest arguments for setting warnings level
110110
id: pytest_warning_flags
111111
env:
112-
PYTEST_WARNING_ARGS: -W "${{ inputs.python_warning_level }}" --continue-on-collection-errors
112+
PYTEST_WARNING_ARGS: -W "${{ inputs.python_warning_level }}" --continue-on-collection-errors "${{ inputs.pytest_additional_args }}"
113113
run: |
114114
if [ "${{ inputs.python_warning_level }}" != "default" ]; then
115115
echo "Setting pytest warning flags"

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ doc/code/api/*
2323
timer.dat
2424
tmp/*
2525
benchmark/revisions/
26-
venv
2726
*venv*/
2827
config.toml
2928
.envrc

doc/introduction/interfaces/jax.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ JAX interface
1010

1111
.. code-block:: bash
1212
13-
pip install jax==0.7.0 jaxlib==0.7.0
13+
pip install jax==0.7.1 jaxlib==0.7.1
1414
1515
You can then import PennyLane and JAX as follows:
1616

doc/releases/changelog-dev.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
that produces a set of gate names to be used as the target gate set in decompositions.
2222
[(#8522)](https://github.com/PennyLaneAI/pennylane/pull/8522)
2323

24+
* The :class:`~pennylane.decomposition.DecompositionGraph` now tracks the minimum number of
25+
dynamic wire allocations required to solve the circuit, and provides a `minimize_work_wires`
26+
option that enables the graph to select the best decomposition rules while minimizing the
27+
number of additional allocations of work wires.
28+
[(#8729)](https://github.com/PennyLaneAI/pennylane/pull/8729)
29+
2430
<h4>Pauli product measurements</h4>
2531

2632
* Added a :func:`~pennylane.ops.pauli_measure` that takes a Pauli product measurement.
@@ -31,6 +37,16 @@
3137

3238
<h3>Improvements 🛠</h3>
3339

40+
* Added a new decomposition, `_decompose_2_cnots`, for the two-qubit decomposition for `QubitUnitary`.
41+
It supports the analytical decomposition a two-qubit unitary known to require exactly 2 CNOTs.
42+
[(#8666)](https://github.com/PennyLaneAI/pennylane/issues/8666)
43+
44+
* Arithmetic dunder methods (`__add__`, `__mul__`, `__rmul__`) have been added to
45+
:class:`~.transforms.core.TransformDispatcher`, :class:`~.transforms.core.TransformContainer`,
46+
and :class:`~.transforms.core.TransformProgram` to enable intuitive composition of transform
47+
programs using `+` and `*` operators.
48+
[(#8703)](https://github.com/PennyLaneAI/pennylane/pull/8703)
49+
3450
* Quantum compilation passes in MLIR and XDSL can now be applied using the core PennyLane transform
3551
infrastructure, instead of using Catalyst-specific tools. This is made possible by a new argument in
3652
:func:`~pennylane.transform` and `~.TransformDispatcher` called ``pass_name``, which accepts a string
@@ -360,6 +376,15 @@
360376

361377
<h3>Internal changes ⚙️</h3>
362378

379+
* Bump `jax` version to `0.7.1` for `capture` module.
380+
[(#8715)](https://github.com/PennyLaneAI/pennylane/pull/8715)
381+
382+
* Bump `jax` version to `0.7.0` for `capture` module.
383+
[(#8701)](https://github.com/PennyLaneAI/pennylane/pull/8701)
384+
385+
* Improve error handling when using PennyLane's experimental program capture functionality with an incompatible JAX version.
386+
[(#8723)](https://github.com/PennyLaneAI/pennylane/pull/8723)
387+
363388
* Bump `autoray` package version to `0.8.2`.
364389
[(#8674)](https://github.com/PennyLaneAI/pennylane/pull/8674)
365390

@@ -515,6 +540,9 @@ A warning message has been added to :doc:`Building a plugin <../development/plug
515540

516541
<h3>Bug fixes 🐛</h3>
517542

543+
* Update `interface-unit-tests.yml` to use its input parameter `pytest_additional_args` when running pytest.
544+
[(#8705)](https://github.com/PennyLaneAI/pennylane/pull/8705)
545+
518546
* Fixes a bug where in `resolve_work_wire_type` we incorrectly returned a value of `zeroed` if `both work_wires`
519547
and `base_work_wires` were empty, causing an incorrect work wire type.
520548
[(#8718)](https://github.com/PennyLaneAI/pennylane/pull/8718)
@@ -564,6 +592,10 @@ A warning message has been added to :doc:`Building a plugin <../development/plug
564592
* Fixes a bug where :func:`~.change_op_basis` cannot be captured when the `uncompute_op` is left out.
565593
[(#8695)](https://github.com/PennyLaneAI/pennylane/pull/8695)
566594

595+
* Fixes a bug where decomposition rules are sometimes incorrectly disregarded by the `DecompositionGraph` when a higher level
596+
decomposition rule uses dynamically allocated work wires.
597+
[(#8725)](https://github.com/PennyLaneAI/pennylane/pull/8725)
598+
567599
* Fixes a bug where :class:`~.ops.ChangeOpBasis` is not correctly reconstructed using `qml.pytrees.unflatten(*qml.pytrees.flatten(op))`
568600
[(#8721)](https://github.com/PennyLaneAI/pennylane/issues/8721)
569601

@@ -587,6 +619,7 @@ Mudit Pandey,
587619
Shuli Shu,
588620
Jay Soni,
589621
nate stemen,
622+
Theodoros Trochatos,
590623
David Wierichs,
591624
Hongsheng Zheng,
592625
Zinan Zhou

doc/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ appdirs
33
autograd
44
autoray
55
cachetools
6-
jax==0.7.0
7-
jaxlib==0.7.0
6+
jax==0.7.1
7+
jaxlib==0.7.1
88
mistune==0.8.4
99
m2r2
1010
# TODO: Remove once galois becomes compatible with latest numpy

pennylane/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,12 @@
204204
from packaging.version import Version as _Version
205205

206206
if _find_spec("jax") is not None:
207-
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.7.0"): # pragma: no cover
207+
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.7.1"): # pragma: no cover
208208
warnings.warn(
209-
"PennyLane is not yet compatible with JAX versions > 0.7.0. "
209+
"PennyLane is not yet compatible with JAX versions > 0.7.1. "
210210
f"You have version {jax_version} installed. "
211-
"Please downgrade JAX to 0.7.0 to avoid runtime errors using "
212-
"python -m pip install jax==0.7.0 jaxlib==0.7.0",
211+
"Please downgrade JAX to 0.7.1 to avoid runtime errors using "
212+
"python -m pip install jax==0.7.1 jaxlib==0.7.1",
213213
RuntimeWarning,
214214
)
215215

pennylane/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.44.0-dev44"
19+
__version__ = "0.44.0-dev46"

pennylane/capture/dynamic_shapes.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
"""
1515
Contains a utility for handling inputs with dynamically shaped arrays.
1616
"""
17-
from collections.abc import Callable, Sequence
17+
from collections.abc import Callable
1818

1919
has_jax = True
2020
try:
2121
import jax
2222
from jax._src.interpreters.partial_eval import TracingEqn
23-
from jax.interpreters import partial_eval as pe
24-
except ImportError: # pragma: no cover
25-
has_jax = False # pragma: no cover
23+
24+
25+
except ImportError as e: # pragma: no cover
26+
has_jax = False
2627

2728

2829
def _get_shape_for_array(x, abstract_shapes: list, previous_ints: list) -> dict:
@@ -123,7 +124,10 @@ def f(n):
123124
124125
"""
125126
if not has_jax: # pragma: no cover
126-
raise ImportError("jax must be installed to use determine_abstracted_axes")
127+
raise ImportError(
128+
"JAX == 0.7.0 must be installed to use determine_abstracted_axes. "
129+
"Install with: pip install jax==0.7.0 jaxlib==0.7.0 "
130+
)
127131
if not jax.config.jax_dynamic_shapes:
128132
return None, ()
129133

@@ -175,18 +179,20 @@ def register_custom_staging_rule(
175179
# See also capture/intro_to_dynamic_shapes.md for dynamic shapes documentation.
176180

177181
def _tracer_and_outvar(
178-
jaxpr_trace: pe.DynamicJaxprTrace,
182+
jaxpr_trace,
179183
outvar: jax.extend.core.Var,
180184
env: dict[jax.extend.core.Var, jax.extend.core.Var],
181-
) -> tuple[pe.DynamicJaxprTracer, jax.extend.core.Var]:
185+
):
182186
"""
183187
Create a new tracer and return var from the true branch outvar.
184188
Returned vars are cached in env for use in future shapes
185189
"""
186190
if not hasattr(outvar.aval, "shape"):
187191
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
188192
new_var = jaxpr_trace.frame.newvar(outvar.aval)
189-
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, outvar.aval, new_var)
193+
out_tracer = jax.interpreters.partial_eval.DynamicJaxprTracer(
194+
jaxpr_trace, outvar.aval, new_var
195+
)
190196
return out_tracer, new_var
191197
new_shape = [s if isinstance(s, int) else env[s] for s in outvar.aval.shape]
192198
if all(isinstance(s, int) for s in outvar.aval.shape):
@@ -195,15 +201,15 @@ def _tracer_and_outvar(
195201
new_aval = jax.core.DShapedArray(tuple(new_shape), outvar.aval.dtype)
196202
# JAX 0.7.0: Create variable first, then pass to DynamicJaxprTracer
197203
new_var = jaxpr_trace.frame.newvar(new_aval)
198-
out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, new_aval, new_var)
204+
out_tracer = jax.interpreters.partial_eval.DynamicJaxprTracer(
205+
jaxpr_trace, new_aval, new_var
206+
)
199207

200208
if not isinstance(outvar, jax.extend.core.Literal):
201209
env[outvar] = new_var
202210
return out_tracer, new_var
203211

204-
def custom_staging_rule(
205-
jaxpr_trace: pe.DynamicJaxprTrace, source_info, *tracers: pe.DynamicJaxprTracer, **params
206-
) -> Sequence[pe.DynamicJaxprTracer] | pe.DynamicJaxprTracer:
212+
def custom_staging_rule(jaxpr_trace, source_info, *tracers, **params):
207213
"""
208214
Add new jaxpr equation to the jaxpr_trace and return new tracers.
209215
"""
@@ -244,4 +250,4 @@ def custom_staging_rule(
244250
jaxpr_trace.frame.add_eqn(tracing_eqn)
245251
return out_tracers
246252

247-
pe.custom_staging_rules[primitive] = custom_staging_rule
253+
jax.interpreters.partial_eval.custom_staging_rules[primitive] = custom_staging_rule

0 commit comments

Comments
 (0)