[Mlir-commits] [mlir] 778494a - [mlir][bufferization] Add bufferization.dealloc canonicalizer to remove unused alloc-dealloc pairs

Martin Erhart llvmlistbot at llvm.org
Mon Aug 28 01:04:20 PDT 2023


Author: Martin Erhart
Date: 2023-08-28T08:04:02Z
New Revision: 778494ae07ec4d2429d93e8b969f9865b82423c0

URL: https://github.com/llvm/llvm-project/commit/778494ae07ec4d2429d93e8b969f9865b82423c0
DIFF: https://github.com/llvm/llvm-project/commit/778494ae07ec4d2429d93e8b969f9865b82423c0.diff

LOG: [mlir][bufferization] Add bufferization.dealloc canonicalizer to remove unused alloc-dealloc pairs

The new pattern removes an allocated value from a `bufferization.dealloc`
operation and erases the allocating operation when the dealloc is the only
user of the allocated value and the allocation has no other side effects.

Deallocation operations where the allocated value is in both the 'memref'
and the 'retained' list are currently not supported. This is because
values in the retained list typically have a use site at a later point,
and another deallocation op exists at that later point to free the memref
then. There already exists a canonicalization pattern in the buffer
deallocation simplification pass that removes the allocated value from
the earlier dealloc, because it can never actually be deallocated there;
thus, this case does not have to be considered in the new pattern.

Differential Revision: https://reviews.llvm.org/D158740

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index c8681374ccae11..83427eb7122afd 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -963,13 +963,101 @@ struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
   }
 };
 
+/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
+/// other user of the allocated value and the allocating operation can be safely
+/// removed. If the same value is present multiple times, this pattern relies on
+/// other canonicalization patterns to remove the duplicate first.
+///
+/// The allocating operation is only considered removable if, apart from reads,
+/// its only memory effect is the allocation of the deallocated value and the
+/// dealloc is its only user. Checking
+/// `hasSingleEffect<MemoryEffects::Allocate>(op, value)` is not sufficient for
+/// this: it only inspects the effects attached to `value`, so an op that
+/// additionally frees or writes other memory would pass and that effect would
+/// be lost when the op is erased.
+///
+/// Example:
+/// ```mlir
+/// %alloc = memref.alloc() : memref<2xi32>
+/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// bufferization.dealloc (%arg0 : ...) if (%true)
+/// ```
+struct RemoveAllocDeallocPairWhenNoOtherUsers
+    : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newMemrefs, newConditions;
+    SmallVector<Operation *> toDelete;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
+        if (isRemovableAllocation(allocOp, memref)) {
+          toDelete.push_back(allocOp);
+          continue;
+        }
+      }
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                      rewriter)))
+      return failure();
+
+    // The dealloc no longer uses the removed values, so the allocating ops
+    // have no users left and can be erased.
+    for (Operation *op : toDelete)
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+
+private:
+  /// Returns true if `allocOp` can be erased together with the deallocation of
+  /// `memref`: it must allocate `memref`, have no memory effects other than
+  /// that allocation and reads, and the dealloc must be the only user of its
+  /// results.
+  static bool isRemovableAllocation(MemoryEffectOpInterface allocOp,
+                                    Value memref) {
+    // `memref` is an operand of the dealloc, so a single use across all results
+    // means that use is the dealloc and no other result is used. This also
+    // rejects values that occur multiple times in the dealloc's operand lists
+    // (e.g., also in the `retained` list).
+    if (!allocOp->hasOneUse())
+      return false;
+
+    SmallVector<MemoryEffects::EffectInstance> effects;
+    allocOp.getEffects(effects);
+    bool allocatesMemref = false;
+    for (const MemoryEffects::EffectInstance &effect : effects) {
+      if (isa<MemoryEffects::Allocate>(effect.getEffect()) &&
+          effect.getValue() == memref) {
+        allocatesMemref = true;
+        continue;
+      }
+      // Reads can be dropped together with the op; any other effect (e.g., a
+      // write, a free, or an allocation of a different value) cannot.
+      if (!isa<MemoryEffects::Read>(effect.getEffect()))
+        return false;
+    }
+    return allocatesMemref;
+  }
+};
+
 } // anonymous namespace
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
   results.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
-              EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
+              EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
+              RemoveAllocDeallocPairWhenNoOtherUsers>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 0f0ac678d25110..12f13743febb73 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -323,12 +323,14 @@ func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2x
 
 // -----
 
-func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) {
+// Return %alloc so that it keeps a user besides the dealloc; otherwise the
+// alloc-dealloc pair would be removed once the metadata extraction is folded.
+func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
   %alloc = memref.alloc() : memref<2xi32>
   %base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
   %base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
   bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
-  return
+  return %alloc : memref<2xi32>
 }
 
 // CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
@@ -337,3 +339,17 @@ func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1,
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
 //  CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
 //  CHECK-NEXT: return
+
+// -----
+
+func.func @dealloc_erase_alloc_without_other_users(%arg0: memref<2xi32>) {
+  %true = arith.constant true
+  %alloc = memref.alloc() : memref<2xi32>
+  bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
+  return
+}
+
+// CHECK-LABEL: func @dealloc_erase_alloc_without_other_users
+//  CHECK-SAME:([[ARG0:%.+]]: memref<2xi32>)
+//   CHECK-NOT: memref.alloc(
+//       CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true)


        


More information about the Mlir-commits mailing list