[Mlir-commits] [mlir] 778494a - [mlir][bufferization] Add bufferization.dealloc canonicalizer to remove unused alloc-dealloc pairs
Martin Erhart
llvmlistbot at llvm.org
Mon Aug 28 01:04:20 PDT 2023
Author: Martin Erhart
Date: 2023-08-28T08:04:02Z
New Revision: 778494ae07ec4d2429d93e8b969f9865b82423c0
URL: https://github.com/llvm/llvm-project/commit/778494ae07ec4d2429d93e8b969f9865b82423c0
DIFF: https://github.com/llvm/llvm-project/commit/778494ae07ec4d2429d93e8b969f9865b82423c0.diff
LOG: [mlir][bufferization] Add bufferization.dealloc canonicalizer to remove unused alloc-dealloc pairs
Deallocation operations where the allocated value is in both the 'memref'
and the 'retained' list are currently not supported. This is because when
values are in the retained list, they typically have a use-site at a later
point, and another deallocation op exists at that later point to free the
memref then. There already exists a canonicalization pattern in the
buffer deallocation simplification pass that removes the allocated value
from the earlier dealloc, because it will never actually be deallocated
there; this case therefore does not have to be handled by the new
pattern.
Differential Revision: https://reviews.llvm.org/D158740
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/Bufferization/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index c8681374ccae11..83427eb7122afd 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -963,13 +963,65 @@ struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
}
};
+/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
+/// other user of the allocated value and the allocating operation can be safely
+/// removed. If the same value is present multiple times, this pattern relies on
+/// other canonicalization patterns to remove the duplicate first.
+///
+/// The allocating op is only erased if it has exactly one result (the memref
+/// passed to the dealloc) and allocating that memref is its only memory
+/// effect. Ops with further results, or with effects on other values (e.g. an
+/// op that also frees or writes an operand buffer), are left untouched because
+/// erasing them would drop uses or side effects that are still observable.
+///
+/// Example:
+/// ```mlir
+/// %alloc = memref.alloc() : memref<2xi32>
+/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// bufferization.dealloc (%arg0 : ...) if (%true)
+/// ```
+struct RemoveAllocDeallocPairWhenNoOtherUsers
+    : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newMemrefs, newConditions;
+    SmallVector<Operation *> toDelete;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
+        // The op must allocate `memref`, and that must be its only memory
+        // effect on any value (hasSingleEffect is queried without a value so
+        // effects on other operands/results also block the rewrite). With a
+        // single result that has a single use (this dealloc), the op becomes
+        // use-free once the operand is dropped below, so `eraseOp` is legal
+        // and the same op can never be queued twice.
+        if (allocOp->getNumResults() == 1 &&
+            allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
+            hasSingleEffect<MemoryEffects::Allocate>(allocOp) &&
+            memref.hasOneUse()) {
+          toDelete.push_back(allocOp);
+          continue;
+        }
+      }
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    // Rewrite the dealloc without the dropped memrefs first, so that the
+    // queued allocating ops lose their last use before they are erased.
+    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                      rewriter)))
+      return failure();
+
+    for (Operation *op : toDelete)
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
} // anonymous namespace
+/// Registers the canonicalization patterns for `bufferization.dealloc`.
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DeallocRemoveDuplicateDeallocMemrefs,
DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
- EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
+ EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
+ RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 0f0ac678d25110..12f13743febb73 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -323,12 +323,12 @@ func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2x
// -----
-func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) {
+func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
%alloc = memref.alloc() : memref<2xi32>
%base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
%base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
- return
+ // Returning %alloc gives it a user besides the dealloc, so the alloc-dealloc
+ // pair removal pattern does not erase it and the CHECKs below still apply.
+ return %alloc : memref<2xi32>
}
// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
@@ -337,3 +337,17 @@ func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1,
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
// CHECK-NEXT: return
+
+// -----
+
+// %alloc is only used by the dealloc, so both the alloc and its entry in the
+// dealloc operand list are removed; %arg0 is kept.
+func.func @dealloc_remove_unused_alloc(%arg0: memref<2xi32>) {
+  %true = arith.constant true
+  %alloc = memref.alloc() : memref<2xi32>
+  bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
+  return
+}
+
+// CHECK-LABEL: func @dealloc_remove_unused_alloc
+// CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
+// CHECK-NOT: memref.alloc(
+// CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true
More information about the Mlir-commits
mailing list