Skip to content

Commit 0400b9a

Browse files
authored
[mlir]: Add handling of escaped memrefs to erase_dead_alloc_and_stores (#167255)
Patch updates transform.memref.erase_dead_alloc_and_stores to not delete escaped allocations.
1 parent dacd2f9 commit 0400b9a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
140140
std::vector<Operation *> opUses;
141141
for (OpOperand &use : op->getUses()) {
142142
Operation *useOp = use.getOwner();
143+
// Use escaped the scope
144+
if (useOp->mightHaveTrait<OpTrait::IsTerminator>())
145+
return false;
143146
if (isa<memref::DeallocOp>(useOp) ||
144147
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
145148
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,23 @@ module attributes {transform.with_named_sequence} {
306306

307307
// -----
308308

309+
// CHECK-LABEL: func.func @dead_alloc_escaped
310+
func.func @dead_alloc_escaped() -> memref<8x64xf32, 3> {
311+
// CHECK: %{{.+}} = memref.alloc
312+
%0 = memref.alloc() : memref<8x64xf32, 3>
313+
return %0 : memref<8x64xf32, 3>
314+
}
315+
316+
module attributes {transform.with_named_sequence} {
317+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
318+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
319+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
320+
transform.yield
321+
}
322+
}
323+
324+
// -----
325+
309326
// CHECK-LABEL: func.func @dead_alloc
310327
func.func @dead_alloc() {
311328
// CHECK-NOT: %{{.+}} = memref.alloc

0 commit comments

Comments
 (0)