@robtaylor

Summary

JAX 0.8.2+ uses the Shardy partitioner by default (jax_use_shardy_partitioner=True), which emits MLIR bytecode containing the sdy (Shardy) dialect. Without proper dialect registration, IREE fails with:

dialect 'sdy' does not implement the bytecode interface

This PR adds full Shardy dialect support to IREE.

Changes

  • Shardy submodule (third_party/shardy): Add openxla/shardy as a git submodule
  • CMake build support (build_tools/third_party/shardy/): CMake build files for Shardy, since upstream Shardy only provides a Bazel build
  • IREE input plugin (compiler/plugins/input/Shardy/):
    • Registers the sdy dialect via IREE's plugin architecture
    • Provides a StripShardyDialect pass that removes sdy ops/attributes for single-device execution
    • Integrated via iree_compiler_register_plugin() to ensure symbols are included
  • CMake option: IREE_INPUT_SHARDY (ON by default)
  • Test: test_shardy.py verifies JAX works with Shardy enabled

Technical Details

The implementation follows IREE's plugin architecture patterns:

  • ShardySession plugin class with DefaultActivated policy
  • Dialect registration via mlir::sdy::registerAllDialects()
  • Input conversion pipeline strips sdy.sharding attributes (these are metadata-only when running on a single device)
  • Proper plugin registration ensures symbols are linked into libIREECompiler

Test plan

  • Verified locally with JAX 0.8.2 on macOS ARM64
  • Matrix multiplication, JIT compilation, vmap, and grad all work
  • CI: Existing PJRT tests should pass
  • CI: New test_shardy.py added to test runner

🤖 Generated with Claude Code

@robtaylor robtaylor marked this pull request as draft December 17, 2025 02:07
@robtaylor
Author

Note: This PR is a draft and has not yet been fully reviewed for upstream contribution.

Areas that may need attention:

  • Code style/formatting alignment with IREE conventions
  • Error handling in the StripShardyDialect pass
  • Test coverage completeness
  • CI integration verification
  • Documentation requirements

The implementation has been tested locally with JAX 0.8.2 on macOS ARM64, but additional review and testing are needed before it is ready to merge.

@robtaylor robtaylor force-pushed the add-shardy-dialect-support branch from 7092ca5 to 046c926 Compare December 17, 2025 04:24
JAX 0.8.2+ uses the Shardy partitioner by default (jax_use_shardy_partitioner=True),
which emits MLIR bytecode containing the sdy (Shardy) dialect. Without proper
dialect registration, IREE fails with "dialect 'sdy' does not implement the
bytecode interface".

This PR adds:

1. Shardy submodule (openxla/shardy) with CMake build support
   - build_tools/third_party/shardy/ contains CMake build files since
     upstream Shardy only has Bazel

2. New IREE input plugin at compiler/plugins/input/Shardy/
   - Registers the sdy dialect via IREE's plugin architecture
   - Provides StripShardyDialect pass to remove sdy ops/attributes for
     single-device execution (sdy ops are metadata-only sharding annotations)

3. New IREE_INPUT_SHARDY CMake option (ON by default)
   - Enables/disables Shardy dialect support in the compiler

4. Test for Shardy integration (test_shardy.py)
   - Verifies JAX works with Shardy enabled on IREE PJRT backends

Technical details:
- ShardySession plugin class with DefaultActivated policy
- Dialect registration via mlir::sdy::registerAllDialects()
- Input conversion pipeline strips sdy.sharding attributes
- Properly integrated via iree_compiler_register_plugin() to ensure
  symbols are included in libIREECompiler

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Signed-off-by: Rob Taylor <[email protected]>
@robtaylor robtaylor force-pushed the add-shardy-dialect-support branch from 046c926 to 5253b62 Compare December 17, 2025 04:27
@robtaylor robtaylor marked this pull request as ready for review December 17, 2025 04:31
@robtaylor
Author

I've manually reviewed this PR, and it looks good to me.

- Fix iterator invalidation bug in StripShardyDialect.cpp by collecting
  ops first then erasing in reverse order
- Add warning for unexpected Shardy op patterns that can't be handled
- Implement pass registration in registerShardyInputConversionPasses()
- Add StableHLO dependency check in CMake
- Update copyright years to 2025
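
The collect-then-erase pattern from the first bullet can be sketched with a toy model. This is plain Python over a list of stand-in objects, not the MLIR C++ API used by StripShardyDialect.cpp; ToyOp, strip_sdy, and the op names are hypothetical and exist only for illustration. In this toy version, erasing in reverse keeps the remaining collected indices valid.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Stand-in for an IR operation: a name plus an attribute dictionary."""
    name: str
    attrs: dict = field(default_factory=dict)


def strip_sdy(ops):
    """Remove toy 'sdy.*' ops and attributes from a list of ToyOp, in place."""
    # Phase 1: collect the positions to erase without mutating the list,
    # so the traversal never observes a container that is changing under it.
    to_erase = [i for i, op in enumerate(ops) if op.name.startswith("sdy.")]

    # Phase 2: erase in reverse order so earlier indices stay valid.
    for i in reversed(to_erase):
        del ops[i]

    # Drop 'sdy.*' attributes from the remaining ops, again collecting
    # the keys first instead of deleting while iterating over the dict.
    for op in ops:
        for key in [k for k in op.attrs if k.startswith("sdy.")]:
            del op.attrs[key]
    return ops


if __name__ == "__main__":
    module = [
        ToyOp("sdy.mesh"),
        ToyOp("stablehlo.add", {"sdy.sharding": "toy", "other": 1}),
        ToyOp("func.return"),
    ]
    print([op.name for op in strip_sdy(module)])
```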

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Signed-off-by: Rob Taylor <[email protected]>