
Conversation

@sakupan102 (Contributor)

Previously, vector.scatter didn’t support tensor outputs. Upstream changes added that support, so DecomposeMapScatter now handles the tensor case as well.

Closes #22697
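For context, a rough sketch of the kind of op this PR lets DecomposeMapScatter handle. The shapes, names, and exact syntax here are illustrative, not taken from the PR's tests; the key point is that the destination (and result) is now a tensor rather than a memref:

```mlir
// Identity map_scatter writing a 4x16 vector into a tensor destination.
// The decomposition flattens the destination and lowers the writes to
// vector stores/scatters on the flattened value.
%result = iree_linalg_ext.map_scatter %input into %output {
^bb0(%i: index, %j: index):
  %mask = arith.constant true
  iree_linalg_ext.yield %i, %j, %mask : index, index, i1
} : vector<4x16xf32> into tensor<4x16xf32> -> tensor<4x16xf32>
```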

Signed-off-by: Ryutaro Okada <[email protected]>
@hanhanW (Contributor)

hanhanW commented Dec 18, 2025

Hi, thanks for the patch. @Max191 is currently on vacation, and it would be better to have his eyes on this. IMO, work like this benefits from a concrete example so we can see whether it truly fixes anything. We'll need to support the map_scatter op better on the CPU side, where all the pipelines are tensor-based. Let's see whether it helps. For now, would you mind waiting until people are back from vacation?

@Max191 (Contributor) left a comment

I am fine with adding this support now, since it is already mostly implemented in this PR. It still won't have any actual use cases yet, but we could try using it on CPU and see if there are any issues e2e.

I left a few comments on the implementation, but it looks good so far!

Comment on lines -164 to -165
SmallVector<OpFoldResult> outputSizes =
memref::getMixedSizes(rewriter, loc, mapScatterOp.getOutput());

For the output sizes, there is a nice utility function to handle both memref and tensor types:

SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc, Value v);
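A sketch of how the memref-only call could be swapped for that utility, assuming `getDims` is visible from this file (not compiled, just illustrating the suggested replacement):

```cpp
// getDims handles both memref- and tensor-typed values, so this works for
// either output semantics, unlike the memref::getMixedSizes call above:
SmallVector<OpFoldResult> outputSizes =
    getDims(rewriter, loc, mapScatterOp.getOutput());
```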

Comment on lines +183 to +184
flatOutput = tensor::CollapseShapeOp::create(
rewriter, loc, mapScatterOp.getOutput(), reassociations);

This is fine for now, since we don't have any concrete use cases yet, but adding collapse_shape ops on tensors can have more implications than it does for memrefs. It may be fine, but bufferization can often get tripped up by reshapes, which could end up creating extra large allocations. This is part of the challenge with vectorizing map_scatter on tensors, and we need to try it e2e to see if it works or not.

That said, I'm not sure what the best way to represent it will be, so we can reevaluate once we are able to test this in an e2e path.
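For intuition, the flattening that the decomposition performs is just row-major linearization of the output indices, after which each contiguous row can be written with one vector store. A minimal sketch in plain Python (no IREE APIs; all names here are illustrative):

```python
# Conceptual model of the decomposition: collapse the output to 1-D, then
# write contiguous rows at computed flat offsets, mirroring the
# collapse_shape + vector.store pattern.
def flat_index(idx, shape):
    """Row-major linearization, i.e. indexing into a collapsed 1-D view."""
    flat = 0
    for i, extent in zip(idx, shape):
        flat = flat * extent + i
    return flat

rows, cols = 4, 16
inp = [[float(r * cols + c) for c in range(cols)] for r in range(rows)]
flat_output = [0.0] * (rows * cols)          # collapsed 4x16 -> 64 output

for r in range(rows):
    base = flat_index((r, 0), (rows, cols))  # row 3 lands at offset 48
    flat_output[base:base + cols] = inp[r]   # analogous to one vector.store

assert flat_output[48:64] == inp[3]
```

This is also where the bufferization concern bites: on memrefs the collapsed view aliases the original buffer, while on tensors the reshape is a new SSA value that bufferization must prove can reuse the same allocation.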

Comment on lines +392 to +393
tensor::ExpandShapeOp::create(rewriter, loc, mapScatterOp.getOutputType(),
scatterOp.getResult(), reassociations);

I believe this will only work for static shapes. To handle dynamic shapes too, you also need to pass a list of OpFoldResult for the resulting output shape, like what is done here:

auto expandedInitArg = tensor::ExpandShapeOp::create(
rewriter, loc, initArg.getType(), initArg, reassociations, initSizes);

You can get the output shape from the original map_scatter output tensor.
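A sketch of what that might look like at this call site, assuming `tensor::getMixedSizes` is available here (not compiled; names approximate):

```cpp
// Gather the (possibly dynamic) sizes from the original map_scatter output
// tensor, and pass them to ExpandShapeOp as the expanded output shape.
SmallVector<OpFoldResult> outputSizes =
    tensor::getMixedSizes(rewriter, loc, mapScatterOp.getOutput());
auto expanded = tensor::ExpandShapeOp::create(
    rewriter, loc, mapScatterOp.getOutputType(), scatterOp.getResult(),
    reassociations, outputSizes);
```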

// CHECK: %[[EXTRACT_3:.+]] = vector.extract %[[INPUT]][3] : vector<16xf32> from vector<4x16xf32>
// CHECK: vector.store %[[EXTRACT_3]], %[[FLAT_OUTPUT]][%[[EXTRACT_IDX_3]]] : memref<64xf32>, vector<16xf32>

// -----

Please also add a test for dynamic shapes.

Comment on lines -391 to +424
- if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx) {
+ if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx &&
+     mapScatterOp.hasPureBufferSemantics()) {

It shouldn't make a difference whether it is memref or tensor here. Can you support this case too (and add tests for it)?
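One possible shape for that (a hypothetical sketch, not compiled: `writtenTensor` stands in for whatever final tensor value the unit-stride path produces):

```cpp
// With the hasPureBufferSemantics() restriction dropped, the tensor path
// must replace the op with the final written tensor instead of erasing it.
if (mapScatterOp.hasPureTensorSemantics()) {
  rewriter.replaceOp(mapScatterOp, writtenTensor);
} else {
  rewriter.eraseOp(mapScatterOp);
}
```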



Successfully merging this pull request may close: Support decomposition for iree_linalg_ext.map_scatter with tensor output (#22697)