Add Shardy dialect support for JAX 0.8.2+ compatibility #22930
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
JAX 0.8.2+ uses the Shardy partitioner by default (
jax_use_shardy_partitioner=True), which emits MLIR bytecode containing thesdy(Shardy) dialect. Without proper dialect registration, IREE fails with:This PR adds full Shardy dialect support to IREE.
Changes
third_party/shardy): Add openxla/shardy as a git submodulebuild_tools/third_party/shardy/): CMake build files since upstream Shardy only has Bazelcompiler/plugins/input/Shardy/):StripShardyDialectpass to remove sdy ops/attributes for single-device executioniree_compiler_register_plugin()to ensure symbols are includedIREE_INPUT_SHARDY(ON by default)test_shardy.pyverifies JAX works with Shardy enabledTechnical Details
The implementation follows IREE's plugin architecture patterns:
ShardySessionplugin class withDefaultActivatedpolicymlir::sdy::registerAllDialects()sdy.shardingattributes (metadata-only for single-device)libIREECompilerTest plan
test_shardy.pyadded to test runner🤖 Generated with Claude Code