[Mlir-commits] [mlir] c5e8fbb - [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass that removes unnecessary retain values
Martin Erhart
llvmlistbot at llvm.org
Thu Aug 10 06:47:47 PDT 2023
Author: Martin Erhart
Date: 2023-08-10T13:46:52Z
New Revision: c5e8fbbf7151b0a9ec271f0da13cae7d8f7907fb
URL: https://github.com/llvm/llvm-project/commit/c5e8fbbf7151b0a9ec271f0da13cae7d8f7907fb
DIFF: https://github.com/llvm/llvm-project/commit/c5e8fbbf7151b0a9ec271f0da13cae7d8f7907fb.diff
LOG: [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass that removes unnecessary retain values
Adds a pattern that removes memrefs from the `retained` list which are
guaranteed to not alias any memref in the `memrefs` list. The corresponding
result value can be replaced with `false` in that case according to the
operation description.
When applied after BufferDeallocation, this can considerably reduce the
overhead that needs to be added during the lowering of the dealloc operation to
check for aliasing (especially when there is only one element in the `memrefs`
list and all `retained` values can be removed).
Depends on D157398
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157407
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index c5f0450121e73c..bd06b5111b82fd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -154,6 +154,78 @@ struct DeallocRemoveDeallocMemrefsContainedInRetained
AliasAnalysis &aliasAnalysis;
};
+/// Remove memrefs from the `retained` list which are guaranteed to not alias
+/// any memref in the `memrefs` list. The corresponding result value can be
+/// replaced with `false` in that case according to the operation description.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
+/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// return %0#0, %0#1
+/// ```
+/// can be rewritten to the following given that `%r0` and `%r1` do not
+/// alias `%m`:
+/// ```mlir
+/// %false = arith.constant false
+/// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
+/// return %false, %false
+/// ```
+///
+/// Retained memrefs for which aliasing cannot be ruled out are kept on a newly
+/// created dealloc op, and their results are taken from that op in their
+/// original relative order. The pattern fails (creating no IR) if no retained
+/// memref can be removed.
+struct RemoveRetainedMemrefsGuaranteedToNotAlias
+ : public OpRewritePattern<DeallocOp> {
+ RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
+ AliasAnalysis &aliasAnalysis)
+ : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+ /// Returns true unless alias analysis proves that `retainedMemref` does not
+ /// alias any memref in the `memrefs` list of `deallocOp`; any alias result
+ /// other than a definite "no alias" counts as a potential alias.
+ bool potentiallyAliasesMemref(DeallocOp deallocOp,
+ Value retainedMemref) const {
+ for (auto memref : deallocOp.getMemrefs()) {
+ if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
+ return true;
+ }
+ return false;
+ }
+
+ LogicalResult matchAndRewrite(DeallocOp deallocOp,
+ PatternRewriter &rewriter) const override {
+ // `replacements` has one entry per original result. A null entry is a
+ // placeholder for a result that will come from the new dealloc op.
+ SmallVector<Value> newRetainedMemrefs, replacements;
+ // Lazily created so that at most one `false` constant is inserted, and
+ // only when at least one retained memref is actually removed.
+ Value falseValue;
+ auto getOrCreateFalse = [&]() -> Value {
+ if (!falseValue)
+ falseValue = rewriter.create<arith::ConstantOp>(
+ deallocOp.getLoc(), rewriter.getBoolAttr(false));
+ return falseValue;
+ };
+
+ for (auto retainedMemref : deallocOp.getRetained()) {
+ if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
+ newRetainedMemrefs.push_back(retainedMemref);
+ replacements.push_back({});
+ continue;
+ }
+
+ // Proven not to alias any memref in the `memrefs` list: the result is
+ // `false` per the op description.
+ replacements.push_back(getOrCreateFalse());
+ }
+
+ // Nothing was removed. No IR has been created at this point, because the
+ // constant is only materialized on the removal path above.
+ if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
+ return failure();
+
+ auto newDeallocOp = rewriter.create<DeallocOp>(
+ deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
+ newRetainedMemrefs);
+ // Fill the placeholders in order. The i-th null entry corresponds to the
+ // i-th memref pushed into `newRetainedMemrefs`.
+ int i = 0;
+ for (auto &repl : replacements) {
+ if (!repl)
+ repl = newDeallocOp.getUpdatedConditions()[i++];
+ }
+
+ rewriter.replaceOp(deallocOp, replacements);
+ return success();
+ }
+
+private:
+ // Analysis passed in by the pass; held by reference, not owned.
+ AliasAnalysis &aliasAnalysis;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -171,8 +243,9 @@ struct BufferDeallocationSimplificationPass
void runOnOperation() override {
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
RewritePatternSet patterns(&getContext());
- patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(),
- aliasAnalysis);
+ patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
+ RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
+ aliasAnalysis);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index e5de8569353e35..79f1972392ffcb 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -9,12 +9,11 @@ func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg
// CHECK-LABEL: func @dealloc_deallocated_in_retained
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
// -----
@@ -31,9 +30,25 @@ func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+
+// -----
+
+// The CHECK lines below expect both retained values to be proven not to
+// alias %alloc: they are dropped from `retain`, and both results are replaced
+// by a single `false` constant.
+func.func @remove_retained_memrefs_guaranteed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+ %alloc = memref.alloc() : memref<2xi32>
+ %alloc0 = memref.alloc() : memref<2xi32>
+ %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @remove_retained_memrefs_guaranteed_to_not_alias
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
+// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
+// CHECK-NOT: retain
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
More information about the Mlir-commits
mailing list