Skip to content

[Codegen] add FoldExtractSliceOfFillThroughBlockArg pattern to TileAndDistributeToWorkgroups#22983

Closed
bangtianliu wants to merge 1 commit intoiree-org:mainfrom
bangtianliu:fill_extract_pattern
Closed

[Codegen] add FoldExtractSliceOfFillThroughBlockArg pattern to TileAndDistributeToWorkgroups#22983
bangtianliu wants to merge 1 commit intoiree-org:mainfrom
bangtianliu:fill_extract_pattern

Conversation

@bangtianliu
Copy link
Copy Markdown
Contributor

@bangtianliu bangtianliu commented Dec 26, 2025

Problem

In split reduction lowering (e.g., iree_linalg_ext.arg_compare), linalg.fill operations are defined outside nested scf.forall loops. The existing FoldExtractSliceOfBroadcastPattern folds intermediate broadcast + extract_slice ops, exposing the fill result to the inner loop's shared_outs. However, inside the inner loop, tensor.extract_slice on the block argument still needs the fill to be materialized locally for correct bufferization.

%fill = linalg.fill ins(%cst) outs(%empty) -> tensor<4x1xf16>          // Outside loops
scf.forall (%iv) shared_outs(%arg1 = %out_tensor) {                    // Outer loop  
  // After FoldExtractSliceOfBroadcastPattern: broadcast+extract folded away  
  // %fill is now directly available    
  scf.forall ... shared_outs(%arg5 = %fill) {                          // Inner loop    
      %slice5 = tensor.extract_slice %arg5[%i, %j] [1, ?] [1, 1]         // From block arg    
      // %slice5 traces back to %fill but needs initialization inside this loop  
  }
}

This causes bufferization issues because the fill is outside the inner loop, but each workgroup needs its own properly initialized tile.

Solution

This PR adds FoldExtractSliceOfFillThroughBlockArgPattern which traces block arguments back to their defining linalg.fill and creates a new fill on the extracted slice inside the loop:

// Before
%fill = linalg.fill ins(%cst) outs(%empty) -> tensor<4x16xf16>
scf.forall ... shared_outs(%out = %fill) {
  %slice = tensor.extract_slice %out[%iv0, %iv1] [1, 8] [1, 1]
  // %slice used without initialization inside loop
}

// After
scf.forall ... shared_outs(%out = %empty) {
  %slice = tensor.extract_slice %out[%iv0, %iv1] [1, 8] [1, 1]
  %filled = linalg.fill ins(%cst) outs(%slice) -> tensor<1x8xf16>
  // %filled properly initialized inside loop
}

How it works with FoldExtractSliceOfBroadcastPattern

For the nested loop case in split reduction, the patterns are applied together:

  1. First, the existing FoldExtractSliceOfBroadcastPattern simplifies extract_slice(broadcast(fill)) → directly exposes the fill result to the inner forall's shared_outs
  2. Then, the new FoldExtractSliceOfFillThroughBlockArgPattern traces the inner loop's block argument back to the fill and creates a fill inside the innermost loop
// Before 
%fill = linalg.fill ins(%cst) outs(%empty)           // Outside all loops
scf.forall ... {
  %broadcast = linalg.broadcast ins(%fill) ...
  %slice0 = tensor.extract_slice %broadcast ...      
  scf.forall shared_outs(%arg5 = %slice0) {          // Traces to fill via broadcast
    %slice5 = tensor.extract_slice %arg5 ...         // Not initialized inside
    // Race condition: multiple workgroups write without proper init
  }
}
 
 // After 
 scf.forall ... shared_outs(%arg1 = %empty_out) {     // Outer loop
  scf.forall shared_outs(%arg5 = %empty_tile) {      // Inner loop  
    %slice5 = tensor.extract_slice %arg5 ...
    %filled5 = linalg.fill ins(%cst) outs(%slice5)   // Fill in innermost loop
    // Each workgroup tile properly initialized
  }
}

This ensures fill initialization happens at the innermost level where computation occurs, enabling correct bufferization without race conditions.

ci-extra: test_torch

@bangtianliu bangtianliu force-pushed the fill_extract_pattern branch 3 times, most recently from 45f6165 to a5b2ab5 Compare December 28, 2025 03:51
@bangtianliu bangtianliu marked this pull request as draft December 29, 2025 18:35
@bangtianliu bangtianliu force-pushed the fill_extract_pattern branch 2 times, most recently from cd6ffd3 to 8dd5f67 Compare December 30, 2025 20:03
@bangtianliu
Copy link
Copy Markdown
Contributor Author

I verified this PR's changes by commenting out the pattern in FormSplitReductionDispatches https://github.com/iree-org/iree/pull/22394/files#diff-fde5e3279ddcfe46ae2a63b157ddaeae71858f379caa43e2ebc9d57bb5265531R211 in this draft PR: #22394

With the changes from this PR and #22953, O3 can now be used for the llama 8b fp16 quality tests added in #22379.

@bangtianliu bangtianliu marked this pull request as ready for review January 2, 2026 17:06
@bangtianliu bangtianliu requested a review from kuhar January 2, 2026 17:06
…dDistributeToWorkgroups

ci-extra: test_torch
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu force-pushed the fill_extract_pattern branch from 8dd5f67 to ae969b8 Compare January 2, 2026 18:14
@bangtianliu bangtianliu marked this pull request as draft January 2, 2026 18:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant