[Mlir-commits] [mlir] d26eb82 - [mlir][bufferization] DeallocOp canonicalizer removing memrefs that are never deallocated

Martin Erhart llvmlistbot at llvm.org
Wed Aug 2 12:23:20 PDT 2023


Author: Martin Erhart
Date: 2023-08-02T19:12:30Z
New Revision: d26eb822df4131ec0d5ca0b0b91501da53683337

URL: https://github.com/llvm/llvm-project/commit/d26eb822df4131ec0d5ca0b0b91501da53683337
DIFF: https://github.com/llvm/llvm-project/commit/d26eb822df4131ec0d5ca0b0b91501da53683337.diff

LOG: [mlir][bufferization] DeallocOp canonicalizer removing memrefs that are never deallocated

This simplifies the op and avoids unnecessary alias checks introduced during the lowering to memref.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D156807

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6e9610b5e55830..3747ab1562ff42 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -866,11 +866,60 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
   }
 };
 
+/// Removes memrefs from the deallocation list if their associated condition is
+/// always 'false'. The op's result for such a memref is replaced by that value.
+///
+/// Example:
+/// ```
+/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+///                           if (%arg2, %false)
+/// ```
+/// becomes
+/// ```
+/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
+/// ```
+struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newMemrefs, newConditions;
+    SmallVector<Value> replacements;
+
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Never deallocated: the updated condition is the 'false' value itself.
+      if (matchPattern(cond, m_Zero())) {
+        replacements.push_back(cond);
+        continue;
+      }
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+      replacements.push_back({}); // Set from the new op's results below.
+    }
+
+    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+      return failure();
+
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
+    unsigned i = 0;
+    for (auto &repl : replacements)
+      if (!repl)
+        repl = newDeallocOp.getResult(i++);
+
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
+  }
+};
+
 } // anonymous namespace
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeallocRemoveDuplicates, EraseEmptyDealloc>(context);
+  // Register the DeallocOp canonicalization patterns.
+  results.add<DeallocRemoveDuplicates, EraseEmptyDealloc,
+              EraseAlwaysFalseDealloc>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 90231f36623024..96f82f6835dd61 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -309,3 +309,17 @@ func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %
 //  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
 //  CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 //  CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
+
+// -----
+
+func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
+  %c_false = arith.constant false
+  %updated:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%c_false, %arg2)
+  return %updated#0, %updated#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_always_false_condition
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
+//  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+//  CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
+//  CHECK-NEXT: return [[FALSE]], [[V0]] :


        


More information about the Mlir-commits mailing list