-
Notifications
You must be signed in to change notification settings - Fork 715
Enable Dynamic Shapes in Pennylane by Patching Jax locally #8525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
JerryChen97
wants to merge
76
commits into
master
Choose a base branch
from
bump-jax-to-0.7.0
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 74 commits
Commits
Show all changes
76 commits
Select commit
Hold shift + click to select a range
5d32755
bump to 0.7.0
JerryChen97 a0afb1d
not the stable
JerryChen97 bfb1fc8
draft
JerryChen97 5732af2
more?
JerryChen97 6e2baed
Better cast_like and is_abstract
JerryChen97 816c6f6
After (JAX 0.7.0): vjp_func.args[0].args == ([],) for independent fun…
JerryChen97 5de0722
batches of assert value; rval to be hashable as wel
JerryChen97 53a9922
more
JerryChen97 fc3573b
more
JerryChen97 cb18461
enhance make_hashable
JerryChen97 3f2215f
Skip some weird fails for now
JerryChen97 b5150fc
skip tracer frist
JerryChen97 1c24e7e
more dynamic shape skips
JerryChen97 3c33252
patch jax
JerryChen97 d37c711
patch refactored in alignment with Catalyst
JerryChen97 1c556c1
don't import jax arbitrarily
JerryChen97 2c4d31b
more fix?
JerryChen97 4b2bed5
rm xfail
JerryChen97 6ad808d
rm xfail
JerryChen97 63ea71d
rm remains
JerryChen97 ff17844
fix all singles doubles
JerryChen97 2e4353d
Apply suggestions from code review
JerryChen97 7f47bbd
improve
JerryChen97 cb6e58f
deal with pylints
JerryChen97 9a1dfb3
disable protected-access
JerryChen97 555c8fb
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 3d08400
update make hashable
JerryChen97 d5ca52a
more robust sorted call
JerryChen97 91ef7a7
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 53b252d
refine the `_make_hashable` logic
JerryChen97 f328031
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 d161795
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 01c547c
fix output
JerryChen97 a2a3c7b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 6bd7904
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 b8c7f24
temp fix
JerryChen97 821642b
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 a2ab4ac
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 da4d4aa
Apply suggestions from code review
JerryChen97 8d300c6
more slice improvement
JerryChen97 0995d9c
fix
JerryChen97 3107dcf
fix more
JerryChen97 46f2e55
remove restore list
JerryChen97 1ae39b7
get rid of _restore_dict (except for map_wires)
JerryChen97 f8c1da5
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 dcfab13
remove _restore_dict
JerryChen97 714387a
xfail templates
JerryChen97 b6f1410
rm unnecessary local import
JerryChen97 a56f971
two more xfail (all within templates subfolder)
JerryChen97 8ab6e2e
Try not sorting
JerryChen97 9a74f0d
improtve the dev note
JerryChen97 b40954a
clean some remains forgotten to revert
JerryChen97 f3a8016
Update pennylane/capture/custom_primitives.py
JerryChen97 35a8905
remove historical comments that not make sense anymore
JerryChen97 368024a
Update pennylane/capture/custom_primitives.py
JerryChen97 f32f8f9
oooooooops
JerryChen97 f663bf6
move all the imports to the top level
JerryChen97 abdce2b
Update pennylane/capture/jax_patches.py
JerryChen97 b4d0a6b
doc req jax fix
JerryChen97 1dcd7ed
pylint
JerryChen97 7e8849f
jax~=0.6.0 -> ==0.7.0
JerryChen97 ea66ea1
== instead of ~=; jaxlib also update
JerryChen97 b706e7f
xfailed
JerryChen97 6699308
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 de4c9a5
skip those impossible doctests
JerryChen97 3ddbf28
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 b40c40d
Revert "skip those impossible doctests"
JerryChen97 c80f6b3
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 02d2f38
Try: still use 0.6.2 for doctest
JerryChen97 aae5e6f
import from top-level
JerryChen97 2e5284e
remove too obvious comments
JerryChen97 b5c6ea4
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 fd9ac16
Merge branch 'master' into bump-jax-to-0.7.0
JerryChen97 ce13ff7
[Capture] JAX patcher for capture (#8654)
JerryChen97 5b88e87
Update pennylane/workflow/_capture_qnode.py
JerryChen97 89335e7
delete unused
JerryChen97 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
JerryChen97 marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,323 @@ | ||
| # Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
| Runtime patches for JAX internals to fix compatibility issues. | ||
|
|
||
| This module patches JAX internal functions to fix bugs that affect PennyLane's | ||
| capture mechanism. These patches are applied at module import time and are | ||
| version-specific. | ||
|
|
||
| This approach is inspired by Catalyst's JAX patches (see catalyst.jax_extras.patches), | ||
| which similarly monkey-patch JAX internal functions for compatibility. The key insight | ||
| is to add a `make_eqn` helper method to DynamicJaxprTrace that properly creates | ||
| TracingEqn objects, which JAX 0.7.0 requires but doesn't provide in all code paths. | ||
|
|
||
| JAX 0.7.0+ Patches | ||
| ------------------ | ||
|
|
||
| 1. **_dyn_shape_staging_rule**: Fixed dynamic shape handling in lax/lax.py lines 267-275. | ||
|
|
||
| The bug in the original JAX implementation: | ||
| - Uses `pe.new_jaxpr_eqn` which creates a JaxprEqn, but `trace.frame.add_eqn` | ||
| expects a TracingEqn. This causes an AssertionError. | ||
| - This bug affects ALL array creation operations with traced dimensions: | ||
| * jax.numpy.arange(traced_value) | ||
| * jax.numpy.ones((traced_value,)) | ||
| * jax.numpy.zeros(traced_value) | ||
| * Any operation using lax.broadcasted_iota with dynamic shapes | ||
|
|
||
| The fix: | ||
| - For StagingJaxprTrace (has counter): Use `pe.new_eqn_recipe` which properly | ||
| creates equation recipes for dynamic tracing. | ||
| - For DynamicJaxprTrace (no counter): Create TracingEqn directly with proper | ||
| JaxprEqnContext, avoiding the AssertionError. | ||
| - This enables array creation with traced dimensions to work correctly. | ||
|
|
||
| 2. **pjit_staging_rule**: Fixed dynamic shape handling in pjit.py lines 1894-1898. | ||
|
|
||
| The bug in the original JAX implementation: | ||
| - Uses `core.new_jaxpr_eqn` which creates a JaxprEqn, but `trace.frame.add_eqn` | ||
| expects a TracingEqn. This causes an AssertionError. | ||
| - Accesses `arg.var` which doesn't exist for DynamicJaxprTracer objects. | ||
|
|
||
| The fix: | ||
| - Use `pe.new_eqn_recipe` which properly creates equation recipes for dynamic tracing. | ||
| - Wrap outvars in DynamicJaxprTracer instances before creating the equation. | ||
| - Special handling for DynamicJaxprTrace (eval_jaxpr path) to avoid counter errors. | ||
| - This enables pjit operations with dynamic shapes to work correctly. | ||
|
|
||
| Impact | ||
| ------ | ||
| These patches fix many dynamic shape tests that were previously failing due to these JAX bugs: | ||
| - Array creation operations (jnp.arange, jnp.ones, jnp.zeros with traced dimensions) | ||
| - Cond operations with dynamic shapes | ||
| - For loop operations with dynamic shapes | ||
| - While loop operations with dynamic shapes | ||
| - Custom staging rules with dynamic shapes | ||
|
|
||
| Without these patches, any operation creating arrays with traced dimensions would fail | ||
| with AssertionError in trace.frame.add_eqn. | ||
| """ | ||
|
|
||
| # pylint: disable=too-many-arguments | ||
| # pylint: disable=unused-import,no-else-return,unidiomatic-typecheck,use-dict-literal | ||
| # pylint: disable=protected-access | ||
|
|
||
| has_jax = True | ||
| try: | ||
| import jax | ||
| from jax._src import config as jax_config | ||
| from jax._src import core, pjit, source_info_util | ||
| from jax._src.core import JaxprEqnContext, Var | ||
| from jax._src.interpreters import partial_eval as pe | ||
| from jax._src.interpreters.partial_eval import ( | ||
| DynamicJaxprTracer, | ||
| TracingEqn, | ||
| compute_on, | ||
| xla_metadata_lib, | ||
| ) | ||
| from jax._src.lax import lax | ||
| from packaging.version import Version | ||
| except ModuleNotFoundError: # pragma: no cover | ||
| has_jax = False # pragma: no cover | ||
|
|
||
|
|
||
| def _add_make_eqn_helper(): | ||
| """ | ||
| Return a make_eqn helper method to DynamicJaxprTrace. | ||
|
|
||
| This helper properly creates TracingEqn objects, which is needed for JAX 0.7.0 | ||
| compatibility. This is based on Catalyst's approach to the same issue. | ||
|
|
||
| Returns: | ||
| tuple: (DynamicJaxprTrace, "make_eqn", make_eqn). | ||
| """ | ||
|
|
||
| def make_eqn( | ||
| self, | ||
| in_tracers: list, | ||
| out_avals_or_tracers: list, | ||
| primitive, | ||
| params: dict, | ||
| effects: set, | ||
| source_info=None, | ||
| ctx=None, | ||
| ): | ||
| """Create a tracing equation properly. | ||
|
|
||
| Args: | ||
| in_tracers: Input tracers | ||
| out_avals_or_tracers: Output abstract values OR output tracers (with vars already created) | ||
| primitive: The primitive operation | ||
| params: Parameters for the primitive | ||
| effects: Effects of the operation | ||
| source_info: Source information for debugging | ||
| ctx: JaxprEqnContext (created if not provided) | ||
|
|
||
| Returns: | ||
| (eqn, out_tracers): TracingEqn and output tracers | ||
| """ | ||
| source_info = source_info or source_info_util.new_source_info() | ||
| ctx = ctx or JaxprEqnContext( | ||
| compute_on.current_compute_type(), | ||
| jax_config.threefry_partitionable.value, | ||
| xla_metadata_lib.current_xla_metadata(), | ||
| ) | ||
|
|
||
| # Normalize out_avals to a list | ||
| if not isinstance(out_avals_or_tracers, (list, tuple)): | ||
| out_avals = [out_avals_or_tracers] | ||
| else: | ||
| out_avals = out_avals_or_tracers | ||
|
|
||
| outvars = [self.frame.newvar(aval) for aval in out_avals] | ||
|
|
||
| if jax_config.enable_checks.value: | ||
| assert all(isinstance(x, DynamicJaxprTracer) for x in in_tracers) | ||
| assert all(isinstance(v, Var) for v in outvars) | ||
|
|
||
| eqn = TracingEqn(list(in_tracers), outvars, primitive, params, effects, source_info, ctx) | ||
|
|
||
| # Create output tracers - manually create DynamicJaxprTracer objects | ||
| # We pass the equation as the parent parameter (4th argument to __init__) | ||
| out_tracers = [ | ||
| DynamicJaxprTracer(self, aval, v, source_info, eqn) | ||
| for aval, v in zip(out_avals, outvars) | ||
| ] | ||
|
|
||
| return eqn, out_tracers | ||
|
|
||
| return (pe.DynamicJaxprTrace, "make_eqn", make_eqn) | ||
|
|
||
|
|
||
| def _patch_dyn_shape_staging_rule(): | ||
| """ | ||
| Return _dyn_shape_staging_rule patch to fix dynamic shape handling. | ||
|
|
||
| The bug in JAX 0.7.0's lax/lax.py lines 267-275 is that it uses: | ||
| - pe.new_jaxpr_eqn instead of proper TracingEqn creation | ||
|
|
||
| This causes an AssertionError when add_eqn expects a TracingEqn but gets a JaxprEqn. | ||
| This affects all array creation operations with traced dimensions like jnp.arange, | ||
| jnp.ones, jnp.zeros, etc. | ||
|
|
||
| The fix uses the make_eqn helper to properly create TracingEqn objects. | ||
|
|
||
| Returns: | ||
| list: List of patch tuples. | ||
| """ | ||
|
|
||
| def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params): | ||
| """Patched version of _dyn_shape_staging_rule using make_eqn helper.""" | ||
| # Check if we have a StagingJaxprTrace (has counter) or DynamicJaxprTrace (no counter) | ||
| if hasattr(trace, "counter"): | ||
| # StagingJaxprTrace path - use new_eqn_recipe (original JAX approach) | ||
| var = trace.frame.newvar(out_aval) | ||
| out_tracer = pe.DynamicJaxprTracer(trace, out_aval, var, source_info) | ||
| eqn = pe.new_eqn_recipe( | ||
| trace, args, [out_tracer], prim, params, core.no_effects, source_info | ||
| ) | ||
| out_tracer.recipe = eqn | ||
| return out_tracer | ||
|
|
||
| # DynamicJaxprTrace path - use make_eqn helper | ||
| eqn, out_tracers = trace.make_eqn( | ||
| args, out_aval, prim, params, core.no_effects, source_info | ||
| ) | ||
| trace.frame.add_eqn(eqn) | ||
| # Return single tracer (not list) since out_aval is a single value | ||
| return out_tracers[0] | ||
|
|
||
| # Return just the core patch - the wrappers will call the patched version | ||
| return [ | ||
| (lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule), | ||
| ] | ||
|
|
||
|
|
||
| def _patch_pjit_staging_rule(): | ||
| """ | ||
| Return pjit_staging_rule patch to fix dynamic shape handling. | ||
|
|
||
| The bug in JAX 0.7.0's pjit.py lines 1894-1898 is that it uses: | ||
| - core.new_jaxpr_eqn instead of pe.new_eqn_recipe | ||
| - arg.var instead of accessing the correct tracer value | ||
|
|
||
| This causes an AssertionError when add_eqn expects a TracingEqn but gets a JaxprEqn. | ||
|
|
||
| Returns: | ||
| list: List of patch tuples. | ||
| """ | ||
| # Store the original function | ||
| original_staging_rule = pjit.pjit_staging_rule | ||
|
|
||
| def patched_pjit_staging_rule(trace, source_info, *args, **params): | ||
| """Patched version of pjit_staging_rule with dynamic shape fixes.""" | ||
| # Use the original implementation for most cases | ||
| if not jax_config.dynamic_shapes.value: | ||
| return original_staging_rule(trace, source_info, *args, **params) | ||
|
|
||
| # Check if we're in the inline path | ||
| if ( | ||
| params["inline"] | ||
| and all(isinstance(i, pjit.UnspecifiedValue) for i in params["in_shardings"]) | ||
| and all(isinstance(o, pjit.UnspecifiedValue) for o in params["out_shardings"]) | ||
| and all(i is None for i in params["in_layouts"]) | ||
| and all(o is None for o in params["out_layouts"]) | ||
| ): | ||
| # Use original for inline path | ||
| return original_staging_rule(trace, source_info, *args, **params) | ||
|
|
||
| jaxpr = params["jaxpr"] | ||
|
|
||
| # This is the dynamic shapes path that needs fixing | ||
| jaxpr, in_fwd, out_shardings, out_layouts = pjit._pjit_forwarding( | ||
| jaxpr, params["out_shardings"], params["out_layouts"] | ||
| ) | ||
| params = { | ||
| **params, | ||
| "jaxpr": jaxpr, | ||
| "out_shardings": out_shardings, | ||
| "out_layouts": out_layouts, | ||
| } | ||
|
|
||
| # Fix 1: Use list instead of map to create outvars | ||
| outvars = [trace.frame.newvar(aval) for aval in pjit._out_type(jaxpr)] | ||
|
|
||
| # DynamicJaxprTrace (eval_jaxpr path) vs StagingJaxprTrace need different approaches | ||
| # DynamicJaxprTrace doesn't have 'counter' attribute | ||
| if not hasattr(trace, "counter"): | ||
| # For DynamicJaxprTrace: Use make_eqn helper (from _add_make_eqn_helper) | ||
| # Pass avals (not tracers) to make_eqn | ||
| in_tracers = [core.get_referent(arg) for arg in args] | ||
| out_avals = [v.aval for v in outvars] | ||
| eqn, out_tracers = trace.make_eqn( | ||
| in_tracers, out_avals, pjit.jit_p, params, jaxpr.effects, source_info | ||
| ) | ||
| trace.frame.add_eqn(eqn) | ||
| else: | ||
| # For StagingJaxprTrace: Use new_eqn_recipe | ||
| eqn = pe.new_eqn_recipe( | ||
| trace, | ||
| args, | ||
| [pe.DynamicJaxprTracer(trace, v.aval, v, source_info) for v in outvars], | ||
| pjit.jit_p, | ||
| params, | ||
| jaxpr.effects, | ||
| source_info, | ||
| ) | ||
|
|
||
| out_tracers = [pe.DynamicJaxprTracer(trace, v.aval, v, source_info) for v in outvars] | ||
| for t in out_tracers: | ||
| t.recipe = eqn | ||
|
|
||
| # Handle forwarding | ||
| out_tracers_ = iter(out_tracers) | ||
| out_tracers = [args[f] if isinstance(f, int) else next(out_tracers_) for f in in_fwd] | ||
| assert next(out_tracers_, None) is None | ||
|
|
||
| return out_tracers | ||
|
|
||
| return [ | ||
| (pjit, "pjit_staging_rule", patched_pjit_staging_rule), | ||
| (pe.custom_staging_rules, "__dict_item__", pjit.jit_p, patched_pjit_staging_rule), | ||
| ] | ||
|
|
||
|
|
||
| def get_jax_patches(): | ||
| """Get patch tuples for use with Patcher context manager. | ||
|
|
||
| Returns a tuple of (obj, attr, new_value) tuples that can be passed to Patcher. | ||
| These patches fix JAX 0.7.0+ compatibility issues for dynamic shapes and pjit. | ||
|
|
||
| Returns: | ||
| tuple: Patch tuples for Patcher, or empty tuple if patches not needed | ||
|
|
||
| Example: | ||
| >>> from pennylane.capture.patching import Patcher | ||
| >>> from pennylane.capture.jax_patches import get_jax_patches | ||
| >>> with Patcher(*get_jax_patches()): | ||
| ... # JAX operations with patches applied | ||
| ... jaxpr = jax.make_jaxpr(my_function)(args) | ||
| """ | ||
| if not has_jax: | ||
| return () | ||
|
|
||
| patches = [] | ||
|
|
||
| # Get all patches from the helper functions | ||
| patches.append(_add_make_eqn_helper()) | ||
| patches.extend(_patch_dyn_shape_staging_rule()) | ||
| patches.extend(_patch_pjit_staging_rule()) | ||
|
|
||
| return tuple(patches) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are actually all for dynamic shapes. In #8654 we can see that they are only required to be active within the dynamic shape context