Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/documentation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ jobs:
run: |
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
pip install -e .
# TODO: use 0.7.0 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
pip install sybil pytest "jax==0.7.0" "jaxlib==0.7.0" torch matplotlib pyzx

- name: Print Dependencies
run: |
Expand Down
10 changes: 4 additions & 6 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _add_abstract_shapes(f):
>>> @qml.capture.FlatFn
... def f(x):
... return x + 1
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand All @@ -48,7 +48,7 @@ def _add_abstract_shapes(f):
] 1.0:f64[] a
d:f64[a] = add b c
in (d,) }
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand Down Expand Up @@ -454,10 +454,8 @@ def ansatz_false():
ansatz_true = circuit(1.4)
ansatz_false = circuit(1.6)

>>> jnp.allclose(ansatz_true, jnp.cos(1.4))
Array(True, dtype=bool)
>>> jnp.allclose(ansatz_false, jnp.cos(1.6))
Array(True, dtype=bool)
assert np.allclose(ansatz_true, np.cos(1.4))
assert np.allclose(ansatz_false, np.cos(1.6))

Additional 'else-if' clauses can also be included via the ``elif`` argument:

Expand Down
23 changes: 12 additions & 11 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def dummy_plxpr_transform(

.. code-block:: python

jax.config.update("jax_enable_x64", True)
qml.capture.enable()

@qml.transforms.cancel_inverses
Expand All @@ -326,21 +327,21 @@ def circuit():
>>> jax.make_jaxpr(circuit)()
{ lambda ; . let
a:AbstractMeasurement(n_wires=None) = transform[
args_slice=slice(0, 0, None)
consts_slice=slice(0, 0, None)
args_slice=(0, 0, None)
consts_slice=(0, 0, None)
inner_jaxpr={ lambda ; . let
_:AbstractOperator() = PauliX[n_wires=1] 0:i...[]
_:AbstractOperator() = S[n_wires=1] 1:i...[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i...[]
b:AbstractOperator() = S[n_wires=1] 1:i...[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i64[]
_:AbstractOperator() = S[n_wires=1] 1:i64[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i64[]
b:AbstractOperator() = S[n_wires=1] 1:i64[]
_:AbstractOperator() = Adjoint b
c:AbstractOperator() = PauliZ[n_wires=1] 1:i...[]
c:AbstractOperator() = PauliZ[n_wires=1] 1:i64[]
d:AbstractMeasurement(n_wires=None) = expval_obs c
in (d,) }
targs_slice=slice(0, None, None)
tkwargs={}
in (d,) }
targs_slice=(0, None, None)
tkwargs=()
transform=<transform: cancel_inverses>
]
]
in (a,) }


Expand Down
Loading