[Mlir-commits] [mlir] b0688ed - [mlir][bufferization] Add DeallocOp canonicalizer to remove memrefs also present in the retained list

Martin Erhart llvmlistbot at llvm.org
Fri Jul 28 09:41:30 PDT 2023


Author: Martin Erhart
Date: 2023-07-28T16:41:03Z
New Revision: b0688ed0dcb67b6582df3cbef0910257bb1f75f2

URL: https://github.com/llvm/llvm-project/commit/b0688ed0dcb67b6582df3cbef0910257bb1f75f2
DIFF: https://github.com/llvm/llvm-project/commit/b0688ed0dcb67b6582df3cbef0910257bb1f75f2.diff

LOG: [mlir][bufferization] Add DeallocOp canonicalizer to remove memrefs also present in the retained list

Since memrefs in the retained list will never be deallocated, we can remove them from the list of memrefs to be deallocated. If the list of memrefs to deallocate becomes empty, we can just delete the dealloc operation.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D156186

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6f475735d633c5..6e9610b5e55830 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -791,11 +791,19 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     // Unique memrefs to be deallocated.
+    DenseSet<Value> retained(deallocOp.getRetained().begin(),
+                             deallocOp.getRetained().end());
     DenseMap<Value, unsigned> memrefToCondition;
     SmallVector<Value> newMemrefs, newConditions, newRetained;
-    SmallVector<unsigned> resultIndices;
-    for (auto [memref, cond] :
-         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+    SmallVector<int32_t> resultIndices(deallocOp.getMemrefs().size(), -1);
+    for (auto [i, memref, cond] :
+         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (retained.contains(memref)) {
+        rewriter.replaceAllUsesWith(deallocOp.getResult(i),
+                                    deallocOp.getConditions()[i]);
+        continue;
+      }
+
       if (memrefToCondition.count(memref)) {
         // If the dealloc conditions don't match, we need to make sure that the
         // dealloc happens on the union of cases.
@@ -808,7 +816,7 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
         newMemrefs.push_back(memref);
         newConditions.push_back(cond);
       }
-      resultIndices.push_back(memrefToCondition[memref]);
+      resultIndices[i] = memrefToCondition[memref];
     }
 
     // Unique retained values
@@ -831,19 +839,38 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
     auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
                                                  newConditions, newRetained);
     for (auto [i, newIdx] : llvm::enumerate(resultIndices))
-      rewriter.replaceAllUsesWith(deallocOp.getResult(i),
-                                  newDealloc.getResult(newIdx));
+      if (newIdx != -1)
+        rewriter.replaceAllUsesWith(deallocOp.getResult(i),
+                                    newDealloc.getResult(newIdx));
 
     rewriter.eraseOp(deallocOp);
     return success();
   }
 };
 
+/// Erase deallocation operations where the variadic list of memrefs to
+/// deallocate is empty. Example:
+/// ```mlir
+/// bufferization.dealloc retain (%arg0: memref<2xi32>)
+/// ```
+struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    if (deallocOp.getMemrefs().empty()) {
+      rewriter.eraseOp(deallocOp);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // anonymous namespace
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeallocRemoveDuplicates>(context);
+  results.add<DeallocRemoveDuplicates, EraseEmptyDealloc>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index d8088b627d23a8..90231f36623024 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -282,15 +282,30 @@ func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<
 
 // -----
 
-func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
-  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg4, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
+func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
+  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
   %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
   return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @dealloc_canonicalize_duplicates
-//  CHECK-SAME:  ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>)
-//  CHECK-NEXT:   [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG4]] : memref<2xi32>, memref<2xi32>)
+//  CHECK-SAME:  ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
+//  CHECK-NEXT:   [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
 //  CHECK-NEXT:   [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
 //  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
 //  CHECK-NEXT:   return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :
+
+// -----
+
+func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) {
+  %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
+  %1:2 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+  bufferization.dealloc
+  bufferization.dealloc retain (%arg0 : memref<2xi32>)
+  return %0, %1#0, %1#1 : i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+//  CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
+//  CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :


        


More information about the Mlir-commits mailing list