Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/documentation-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ jobs:
run: |
pip install --upgrade --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ pennylane-catalyst pennylane-lightning
pip install -e .
# TODO: use 0.7.0 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
pip install sybil pytest "jax==0.7.0" "jaxlib==0.7.0" torch matplotlib pyzx

- name: Print Dependencies
run: |
Expand Down
12 changes: 5 additions & 7 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _add_abstract_shapes(f):
>>> @qml.capture.FlatFn
... def f(x):
... return x + 1
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(f, abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand All @@ -48,7 +48,7 @@ def _add_abstract_shapes(f):
] 1.0:f64[] a
d:f64[a] = add b c
in (d,) }
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4))
>>> jax.make_jaxpr(_add_abstract_shapes(f), abstracted_axes={0:"a"})(jnp.zeros(4)) # doctest: +SKIP
{ lambda ; a:i64[] b:f64[a]. let
c:f64[a] = broadcast_in_dim[
broadcast_dimensions=()
Expand Down Expand Up @@ -434,7 +434,7 @@ def qnode(x, y):

In just-in-time (JIT) mode using the :func:`~.qjit` decorator,

.. code-block:: python
.. code-block::

dev = qml.device("lightning.qubit", wires=1)

Expand All @@ -454,10 +454,8 @@ def ansatz_false():
ansatz_true = circuit(1.4)
ansatz_false = circuit(1.6)

>>> jnp.allclose(ansatz_true, jnp.cos(1.4))
Array(True, dtype=bool)
>>> jnp.allclose(ansatz_false, jnp.cos(1.6))
Array(True, dtype=bool)
assert jnp.allclose(ansatz_true, jnp.cos(1.4))
assert jnp.allclose(ansatz_false, jnp.cos(1.6))

Additional 'else-if' clauses can also be included via the ``elif`` argument:

Expand Down
22 changes: 11 additions & 11 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,21 +253,21 @@ def circuit():
>>> jax.make_jaxpr(circuit)()
{ lambda ; . let
a:AbstractMeasurement(n_wires=None) = transform[
args_slice=slice(0, 0, None)
consts_slice=slice(0, 0, None)
args_slice=(0, 0, None)
consts_slice=(0, 0, None)
inner_jaxpr={ lambda ; . let
_:AbstractOperator() = PauliX[n_wires=1] 0:i...[]
_:AbstractOperator() = S[n_wires=1] 1:i...[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i...[]
b:AbstractOperator() = S[n_wires=1] 1:i...[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i64[]
_:AbstractOperator() = S[n_wires=1] 1:i64[]
_:AbstractOperator() = PauliX[n_wires=1] 0:i64[]
b:AbstractOperator() = S[n_wires=1] 1:i64[]
_:AbstractOperator() = Adjoint b
c:AbstractOperator() = PauliZ[n_wires=1] 1:i...[]
c:AbstractOperator() = PauliZ[n_wires=1] 1:i64[]
d:AbstractMeasurement(n_wires=None) = expval_obs c
in (d,) }
targs_slice=slice(0, None, None)
tkwargs={}
in (d,) }
targs_slice=(0, None, None)
tkwargs=()
transform=<transform: cancel_inverses>
]
]
in (a,) }


Expand Down
Loading