Add full broadcasting support to LayerNormalization and RMSNormalization #26613
+1,290
−24
Description
This PR adds full and spec-compliant broadcasting support to both LayerNormalization and RMSNormalization.
Previously, onnxruntime supported only a subset of broadcasting cases (based on the logic introduced in PR #23297).
That implementation handled several cases but did not cover all valid broadcasting scenarios.
This PR introduces a complete generic broadcasting path, following the ONNX specification rules.
The previous implementation is preserved as a fast-path and is still used whenever the Scale/Bias shapes match directly.
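As a rough illustration of the split between the fast-path and the generic path, the sketch below classifies a Scale (or Bias) shape against the input shape. It is not the code from this PR: `is_broadcastable_to` and `choose_path` are hypothetical names, "match directly" is assumed here to mean "equal to the normalized shape `X.shape[axis:]`", and the generic check assumes standard right-aligned, NumPy-style unidirectional broadcasting to the shape of X.

```python
from typing import Sequence


def is_broadcastable_to(src: Sequence[int], dst: Sequence[int]) -> bool:
    """Return True if `src` can be unidirectionally broadcast to `dst`.

    Dimensions are compared right-aligned; each source dimension must
    either equal the target dimension or be 1.
    """
    if len(src) > len(dst):
        return False
    for s, d in zip(reversed(src), reversed(dst)):
        if s != 1 and s != d:
            return False
    return True


def choose_path(x_shape: Sequence[int], scale_shape: Sequence[int], axis: int) -> str:
    """Classify a Scale/Bias shape (hypothetical helper, not the PR's API)."""
    axis = axis % len(x_shape)  # map a negative axis into [0, rank)
    normalized_shape = tuple(x_shape[axis:])
    # Assumption: the fast-path applies when the shape equals the normalized shape.
    if tuple(scale_shape) == normalized_shape:
        return "fast"
    # Otherwise fall back to a generic broadcasting check against X's shape.
    if is_broadcastable_to(scale_shape, x_shape):
        return "generic"
    return "invalid"
```

For example, with `x_shape = (2, 3, 4)` and `axis = -1`, a Scale of shape `(4,)` would take the fast-path under these assumptions, while a Scale of shape `(3, 1)` would go through the generic path.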
Main changes:
- Extended broadcasting logic in:
  - layer_norm_helper.h
  - layer_norm_impl.cc
- Added full support for all valid broadcasting configurations of Scale and Bias.
- Preserved previous partial logic as a fast-path for exact-match cases.
- Added comprehensive tests to:
  - layer_norm_op_test.cc
  - rms_norm_op_test.cc
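One common way to implement a generic broadcasting path is to give the Scale/Bias tensor a stride of 0 along every broadcast dimension, so each output element can look up its Scale value by offset. The sketch below shows that technique in isolation. It is an assumption about how such a path could work, not a description of the code in `layer_norm_impl.cc`. The helper names are hypothetical, and the source shape is assumed to already be broadcastable to the target shape.

```python
from typing import List, Sequence


def broadcast_strides(src_shape: Sequence[int], dst_shape: Sequence[int]) -> List[int]:
    """Element strides for reading a contiguous `src` as if it had `dst_shape`.

    Broadcast dimensions (size 1 in the left-padded source shape) get stride 0.
    Assumes `src_shape` is broadcastable to `dst_shape`; no validation is done.
    """
    pad = len(dst_shape) - len(src_shape)
    padded = (1,) * pad + tuple(src_shape)
    strides = [0] * len(dst_shape)
    running = 1
    for i in range(len(dst_shape) - 1, -1, -1):
        if padded[i] != 1:
            strides[i] = running
        running *= padded[i]
    return strides


def scale_offset(flat_index: int, dst_shape: Sequence[int], strides: Sequence[int]) -> int:
    """Map a flat (row-major) index into the output to an offset into Scale."""
    offset = 0
    for dim, stride in zip(reversed(dst_shape), reversed(strides)):
        flat_index, coord = divmod(flat_index, dim)
        offset += coord * stride
    return offset
```

With an input of shape `(2, 3, 4)` and a Scale of shape `(3, 1)`, every element in row `j` of the middle dimension reads `Scale[j, 0]`, regardless of its position along the first and last dimensions.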
Motivation and Context
Before this fix, some valid ONNX broadcasting shapes were rejected in LayerNormalization and RMSNormalization.
This PR brings the operators into full alignment with the ONNX specification and fixes models that previously failed due to incomplete broadcasting support.
Fixes #26432
Fixes #18184