Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/documentation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ jobs:
run: |
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
pip install -e .
# TODO: use 0.7.1 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
pip install sybil pytest "jax==0.7.1" "jaxlib==0.7.1" torch matplotlib pyzx

- name: Print Dependencies
run: |
Expand Down
6 changes: 3 additions & 3 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def _add_abstract_shapes(f):
Dynamic shape support currently has a lot of dragons. This function is subject to change
at any moment. Use duplicate code till reliable abstractions are found.

>>> jax.config.update("jax_dynamic_shapes", True)
>>> jax.config.update("jax_dynamic_shapes", True) # doctest: +SKIP
>>> jax.config.update("jax_enable_x64", True)
>>> @qml.capture.FlatFn
... def f(x):
... return x + 1
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand All @@ -48,7 +48,7 @@ def _add_abstract_shapes(f):
] 1.0:f64[] a
d:f64[a] = add b c
in (d,) }
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand Down
12 changes: 6 additions & 6 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ def circuit():
>>> jax.make_jaxpr(circuit)()
{ lambda ; . let
a:AbstractMeasurement(n_wires=None) = transform[
args_slice=slice(0, 0, None)
consts_slice=slice(0, 0, None)
args_slice=(0, 0, None)
consts_slice=(0, 0, None)
inner_jaxpr={ lambda ; . let
_:AbstractOperator() = PauliX[n_wires=1] 0:i...[]
_:AbstractOperator() = S[n_wires=1] 1:i...[]
Expand All @@ -336,11 +336,11 @@ def circuit():
_:AbstractOperator() = Adjoint b
c:AbstractOperator() = PauliZ[n_wires=1] 1:i...[]
d:AbstractMeasurement(n_wires=None) = expval_obs c
in (d,) }
targs_slice=slice(0, None, None)
tkwargs={}
in (d,) }
targs_slice=(0, None, None)
tkwargs=()
transform=<transform: cancel_inverses>
]
]
in (a,) }


Expand Down