[LinalgExt] Enable tensor map_scatter lowering #22931
Conversation
Previously, vector.scatter didn't support tensor outputs. Upstream changes added that support, so DecomposeMapScatter now handles the tensor case as well.

Closes iree-org#22697

Signed-off-by: Ryutaro Okada <[email protected]>
Hi, thanks for the patch. @Max191 is currently on vacation, and it's better to have his eyes on this. IMO, this kind of work needs a concrete example so we can see whether it truly fixes anything. We'll need to support the map_scatter op better on the CPU side, where all the pipelines are tensor-based; let's see whether this helps or not. For now, would you mind waiting until people are back from vacation?
Max191 left a comment:
I am fine with adding this support now, since it is already mostly implemented in this PR. It still won't have any actual use cases yet, but we could try using it on CPU and see if there are any issues e2e.
I left a few comments on the implementation, but it looks good so far!
SmallVector<OpFoldResult> outputSizes =
    memref::getMixedSizes(rewriter, loc, mapScatterOp.getOutput());
For the output sizes, there is a nice utility function to handle both memref and tensor types:
SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc, Value v);
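As a rough sketch of how that could look here (assuming the getDims utility above is visible from the DecomposeMapScatter pattern):

```cpp
// Sketch only: per the comment above, getDims handles both memref- and
// tensor-typed values, so a single call covers both output kinds.
SmallVector<OpFoldResult> outputSizes =
    getDims(rewriter, loc, mapScatterOp.getOutput());
```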
flatOutput = tensor::CollapseShapeOp::create(
    rewriter, loc, mapScatterOp.getOutput(), reassociations);
This is fine for now, since we don't have any concrete use cases yet, but adding collapse_shape ops on tensors can have more implications than it does for memrefs. It may be fine, but bufferization can often get tripped up by reshapes, which could end up creating extra large allocations. This is part of the challenge with vectorizing map_scatter on tensors, and we need to try it e2e to see if it works or not.
That said, I'm not sure what the best way to represent it will be, so we can reevaluate once we are able to test this in an e2e path.
tensor::ExpandShapeOp::create(rewriter, loc, mapScatterOp.getOutputType(),
                              scatterOp.getResult(), reassociations);
I believe this will only work for static shapes. To handle dynamic shapes too, you need to also pass a list of OpFoldResult for the resulting output shape, like what is done here:
iree/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp, lines 787 to 788 in f4b596d:
auto expandedInitArg = tensor::ExpandShapeOp::create(
    rewriter, loc, initArg.getType(), initArg, reassociations, initSizes);
You can get the output shape from the original map_scatter output tensor.
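A minimal sketch of that, assuming `outputSizes` holds the mixed static/dynamic sizes of the original map_scatter output (e.g. gathered as suggested in the earlier comment):

```cpp
// Sketch only: passing the output sizes as the output_shape operand lets
// tensor.expand_shape recover dynamic dimensions of the original output.
auto expanded = tensor::ExpandShapeOp::create(
    rewriter, loc, mapScatterOp.getOutputType(), scatterOp.getResult(),
    reassociations, outputSizes);
```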
// CHECK: %[[EXTRACT_3:.+]] = vector.extract %[[INPUT]][3] : vector<16xf32> from vector<4x16xf32>
// CHECK: vector.store %[[EXTRACT_3]], %[[FLAT_OUTPUT]][%[[EXTRACT_IDX_3]]] : memref<64xf32>, vector<16xf32>

// -----
Please also add a test for dynamic shapes.
-  if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx) {
+  if (!isMaskForwardSlice && isUnitFunctionOfInnermostInputIdx &&
+      mapScatterOp.hasPureBufferSemantics()) {
It shouldn't make a difference whether it is memref or tensor here. Can you support this case too (and add tests for it)?