[Mlir-commits] [mlir] d26eb82 - [mlir][bufferization] DeallocOp canonicalizer removing memrefs that are never deallocated
Martin Erhart
llvmlistbot at llvm.org
Wed Aug 2 12:23:20 PDT 2023
Author: Martin Erhart
Date: 2023-08-02T19:12:30Z
New Revision: d26eb822df4131ec0d5ca0b0b91501da53683337
URL: https://github.com/llvm/llvm-project/commit/d26eb822df4131ec0d5ca0b0b91501da53683337
DIFF: https://github.com/llvm/llvm-project/commit/d26eb822df4131ec0d5ca0b0b91501da53683337.diff
LOG: [mlir][bufferization] DeallocOp canonicalizer removing memrefs that are never deallocated
This simplifies the op and avoids unnecessary alias checks introduced during the lowering to memref.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D156807
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/Bufferization/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6e9610b5e55830..3747ab1562ff42 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -866,11 +866,60 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
}
};
+/// Drops every memref whose condition matches `m_Zero()` (a constant 'false')
+/// from the op; uses of its result are replaced by that condition value.
+///
+/// Example:
+/// ```
+/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+/// if (%arg2, %false)
+/// ```
+/// becomes (uses of `%0#1` are replaced by `%false`)
+/// ```
+/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
+/// ```
+struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DeallocOp deallocOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> newMemrefs, newConditions;
+ SmallVector<Value> replacements; // one entry per zipped result, in order
+
+ for (auto [res, memref, cond] :
+ llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(),
+ deallocOp.getConditions())) {
+ if (matchPattern(cond, m_Zero())) {
+ replacements.push_back(cond); // dropped: reuse the constant as the result
+ continue;
+ }
+ newMemrefs.push_back(memref);
+ newConditions.push_back(cond);
+ replacements.push_back({}); // kept: null placeholder, filled in below
+ }
+
+ if (newMemrefs.size() == deallocOp.getMemrefs().size()) // nothing dropped
+ return failure();
+
+ auto newDeallocOp = rewriter.create<DeallocOp>(
+ deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
+ unsigned i = 0; // index of the next unconsumed result of the new op
+ for (auto &repl : replacements)
+ if (!repl)
+ repl = newDeallocOp.getResult(i++);
+
+ rewriter.replaceOp(deallocOp, replacements);
+ return success();
+ }
+};
+
} // anonymous namespace
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DeallocRemoveDuplicates, EraseEmptyDealloc>(context);
+ results
+ .add<DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 90231f36623024..96f82f6835dd61 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -309,3 +309,17 @@ func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
+
+// -----
+
+func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
+ %false = arith.constant false
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_always_false_condition
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
+// CHECK-NEXT: return [[FALSE]], [[V0]] :
More information about the Mlir-commits
mailing list