Skip to content

Conversation

@JerryChen97
Copy link
Contributor

@JerryChen97 JerryChen97 commented Oct 22, 2025

TODO:
[ ] bring back the doctests in condition.py skipped at #8724

Context:
JAX updated their 0.7.x and 0.8.x recently (up to 23rd Oct 2025)

Many things changed, including:

  1. primitive args/kwargs have to be hashable now, because JAX had them cachable now
  2. DynamicJaxprTracer API changed a lot
  3. quite some bugs, e.g. this critical typo, which basically breaks everything about dynamic shapes

Description of the Change:
Many patches, in alignment with Catalyst's patches jax_extra

Benefits:

Possible Drawbacks:

Related GitHub Issues:
[sc-105046]

@JerryChen97 JerryChen97 requested a review from a team as a code owner October 22, 2025 18:38
@JerryChen97 JerryChen97 self-assigned this Oct 22, 2025
@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@JerryChen97 JerryChen97 removed the request for review from a team October 22, 2025 18:39
@JerryChen97 JerryChen97 added the dependencies Pull requests that update a dependency file label Oct 22, 2025
@JerryChen97 JerryChen97 changed the title JAX bump to 0.7.0 [WIP] JAX bump to 0.7.0 Oct 23, 2025
@JerryChen97 JerryChen97 added the WIP 🚧 Work-in-progress label Oct 23, 2025
Comment on lines 276 to 279
# Create JaxprEqnContext and TracingEqn for JAX 0.7.0
# pylint: disable=import-outside-toplevel
from jax._src import compute_on, config, xla_metadata_lib
from jax._src.interpreters.partial_eval import JaxprEqnContext, TracingEqn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need an import-outside-toplevel here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


assert plxpr.eqns[0].primitive == ctrl_transform_prim
assert plxpr.eqns[0].params["control_values"] == [True]
# JAX 0.7.0 requires hashable params, so lists become tuples
Copy link
Contributor

@albi3ro albi3ro Nov 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not entirely sure we need these comments in the test. Even as a PR comment, it is a fairly obvious change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 58 to 59
# TODO: use 0.7.0 after updating all the documentation
pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

github-merge-queue bot pushed a commit that referenced this pull request Nov 29, 2025
**Context:**
Subset of #8525 focusing on
hashability, as well as the jax API updates

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-102157]

---------

Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
**Context:**
During the
[updating](#8525) jax from
`0.6.2` to `0.7.x` it's a huge headache how many random bugs there exist
in new jax. However, we also do not wish to mutate any users' local
environment by patching them globally. Therefore, inspired by Catalyst's
elegant solution, here we put forward Pennylane's own and try to apply
it.

Note that to minimize any further unwanted inter-coupling between
packages, Patchers were deliberated re-invented here.

**Description of the Change:**
All the jax patches will be localized and only be used by dynamic shape
tests

**Benefits:**
Our patches to jax0.7.0 won't be cascading to any downstream package, or
other packages sharing the same jax env.

**Possible Drawbacks:**
After bump jax to 0.7.0, dynamic shape won't be automatically available.
Instead, one has to use the patching context mgr to run any dynamic
shape related code.

**Related GitHub Issues:**
@JerryChen97 JerryChen97 changed the title JAX bump to 0.7.0 Enable Dynamic Shapes in Pennylane by Patching Jax locally Dec 1, 2025
@JerryChen97
Copy link
Contributor Author

Due to multiple decision made along the way (too convoluted so I omitted the details here, sorry!), this PR does not do the jax bump any more. Instead, it targets at enablement of our dynamic shapes by reasonably patching jax bugs

github-merge-queue bot pushed a commit that referenced this pull request Dec 4, 2025
-------------------------------------------------------------------------------------------

**Context:**
After bump jax to 0.7.0 and 0.7.1, some doctests are affected. We fix
them here

**Description of the Change:**
1. `transform.py` the expected output tracing info need to be adjusted
as the hashable version
2. `condition.py` contains some dynamic shape functionalities, which
should be skipped for now, until the [patch
PR](#8525) is merged.

**Benefits:**
Documentation Tests should be back to normal and almost always green now
(except for the mysterious `qmc.py` which we have no clue yet)

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
[sc-105091]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Pull requests that update a dependency file do not merge ⚠️ Do not merge the pull request until this label is removed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants