-
Notifications
You must be signed in to change notification settings - Fork 715
Enable Dynamic Shapes in Pennylane by Patching Jax locally #8525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
Hello. You may have forgotten to update the changelog!
|
This reverts commit de4c9a5.
pennylane/workflow/_capture_qnode.py
Outdated
| # Create JaxprEqnContext and TracingEqn for JAX 0.7.0 | ||
| # pylint: disable=import-outside-toplevel | ||
| from jax._src import compute_on, config, xla_metadata_lib | ||
| from jax._src.interpreters.partial_eval import JaxprEqnContext, TracingEqn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need an import-outside-toplevel here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tests/capture/test_nested_plxpr.py
Outdated
|
|
||
| assert plxpr.eqns[0].primitive == ctrl_transform_prim | ||
| assert plxpr.eqns[0].params["control_values"] == [True] | ||
| # JAX 0.7.0 requires hashable params, so lists become tuples |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not entirely sure we need these comments in the test. Even as a PR comment, it is a fairly obvious change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # TODO: use 0.7.0 after updating all the documentation | ||
| pip install sybil pytest "jax~=0.6.0" "jaxlib~=0.6.0" torch matplotlib pyzx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
**Context:** Subset of #8525 focusing on hashability, as well as the jax API updates **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-102157] --------- Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
**Context:** During the [updating](#8525) jax from `0.6.2` to `0.7.x` it's a huge headache how many random bugs there exist in new jax. However, we also do not wish to mutate any users' local environment by patching them globally. Therefore, inspired by Catalyst's elegant solution, here we put forward Pennylane's own and try to apply it. Note that to minimize any further unwanted inter-coupling between packages, Patchers were deliberated re-invented here. **Description of the Change:** All the jax patches will be localized and only be used by dynamic shape tests **Benefits:** Our patches to jax0.7.0 won't be cascading to any downstream package, or other packages sharing the same jax env. **Possible Drawbacks:** After bump jax to 0.7.0, dynamic shape won't be automatically available. Instead, one has to use the patching context mgr to run any dynamic shape related code. **Related GitHub Issues:**
|
Due to multiple decision made along the way (too convoluted so I omitted the details here, sorry!), this PR does not do the jax bump any more. Instead, it targets at enablement of our dynamic shapes by reasonably patching jax bugs |
------------------------------------------------------------------------------------------- **Context:** After bump jax to 0.7.0 and 0.7.1, some doctests are affected. We fix them here **Description of the Change:** 1. `transform.py` the expected output tracing info need to be adjusted as the hashable version 2. `condition.py` contains some dynamic shape functionalities, which should be skipped for now, until the [patch PR](#8525) is merged. **Benefits:** Documentation Tests should be back to normal and almost always green now (except for the mysterious `qmc.py` which we have no clue yet) **Possible Drawbacks:** N/A **Related GitHub Issues:** [sc-105091]
TODO:
[ ] bring back the doctests in
condition.pyskipped at #8724Context:
JAX updated their 0.7.x and 0.8.x recently (up to 23rd Oct 2025)
Many things changed, including:
DynamicJaxprTracerAPI changed a lotDescription of the Change:
Many patches, in alignment with Catalyst's patches
jax_extraBenefits:
Possible Drawbacks:
Related GitHub Issues:
[sc-105046]