[Mlir-commits] [mlir] 0bcae5e - [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass
Martin Erhart
llvmlistbot at llvm.org
Tue Aug 15 05:40:35 PDT 2023
Author: Martin Erhart
Date: 2023-08-15T12:39:57Z
New Revision: 0bcae5e763c7bf951bad06f5913a9905f065539d
URL: https://github.com/llvm/llvm-project/commit/0bcae5e763c7bf951bad06f5913a9905f065539d
DIFF: https://github.com/llvm/llvm-project/commit/0bcae5e763c7bf951bad06f5913a9905f065539d.diff
LOG: [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass
Add a pattern that splits one dealloc operation into multiple dealloc operations
depending on static aliasing information of the values in the `memref` operand
list. This reduces the total number of aliasing checks required at runtime and
can enable further canonicalizations of the new and simplified dealloc
operations.
Depends on D157407
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157508
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index bd06b5111b82fd..86b615575544b0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -48,6 +48,24 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
return success();
}
+/// Returns true if `memref` may or must alias at least one value in
+/// `memrefList` according to `analysis`, i.e. whenever aliasing cannot be
+/// ruled out. Optimization patterns typically require that no such alias
+/// exists before they simplify a dealloc. If `allowSelfAlias` is set, entries
+/// of `memrefList` that are the very same SSA value as `memref` are ignored,
+/// so the helper can be used both when `memref` is known to be in the list and
+/// when it must not be. Note that this skips *every* entry equal to `memref`;
+/// a duplicated entry is therefore not reported as an alias. Callers that must
+/// detect duplicates should exclude the queried position from the list instead.
+static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+                                     ValueRange memrefList, Value memref,
+                                     bool allowSelfAlias) {
+  return llvm::any_of(memrefList, [&](Value candidate) {
+    // Identical SSA values are deliberately ignored when self-aliasing is
+    // allowed.
+    if (allowSelfAlias && candidate == memref)
+      return false;
+    return !analysis.alias(candidate, memref).isNo();
+  });
+}
+
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
@@ -176,15 +194,6 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
- bool potentiallyAliasesMemref(DeallocOp deallocOp,
- Value retainedMemref) const {
- for (auto memref : deallocOp.getMemrefs()) {
- if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
- return true;
- }
- return false;
- }
-
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newRetainedMemrefs, replacements;
@@ -197,7 +206,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
};
for (auto retainedMemref : deallocOp.getRetained()) {
- if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
+ if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+ retainedMemref, false)) {
newRetainedMemrefs.push_back(retainedMemref);
replacements.push_back({});
continue;
@@ -226,6 +236,85 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
AliasAnalysis &aliasAnalysis;
};
+/// Split off memrefs to separate dealloc operations to reduce the number of
+/// runtime checks required and enable further canonicalization of the new and
+/// simpler dealloc operations. A memref can be split off if it is guaranteed to
+/// not alias with any other entry of the `memref` operand list. Aliasing is
+/// checked against the other *positions* of the list, so a memref that occurs
+/// more than once is kept in a single dealloc instead of being split off (and
+/// thus deallocated) twice. The results of the old and the new dealloc
+/// operations are combined by computing their element-wise disjunction.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
+///                           if (%cond0, %cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// return %0#0, %0#1
+/// ```
+/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
+/// canonicalized to the following, thus reducing the number of runtime alias
+/// checks by 1 and potentially enabling further canonicalization of the new
+/// split-up dealloc operations. (The original dealloc operation is kept with
+/// the memrefs that could not be split off; here it is left with none and is
+/// not shown.)
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %2 = arith.ori %0#0, %1#0
+/// %3 = arith.ori %0#1, %1#1
+/// return %2, %3
+/// ```
+struct SplitDeallocWhenNotAliasingAnyOther
+    : public OpRewritePattern<DeallocOp> {
+  SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
+                                      AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefs = deallocOp.getMemrefs();
+    auto conditions = deallocOp.getConditions();
+    if (memrefs.size() <= 1)
+      return failure();
+
+    Location loc = deallocOp.getLoc();
+    // The split-off deallocs use the same retained values and conditions as
+    // `deallocOp`, which are all available right before it.
+    rewriter.setInsertionPoint(deallocOp);
+
+    SmallVector<Value> remainingMemrefs, remainingConditions;
+    SmallVector<DeallocOp> splitDeallocs;
+    for (size_t i = 0, e = memrefs.size(); i < e; ++i) {
+      Value memref = memrefs[i];
+      // Check against every *other position* of the list instead of skipping
+      // entries equal to `memref`: a duplicated memref must stay in a single
+      // dealloc, otherwise it would be deallocated twice.
+      bool mayAliasOther =
+          potentiallyAliasesMemref(aliasAnalysis, memrefs.take_front(i),
+                                   memref, /*allowSelfAlias=*/false) ||
+          potentiallyAliasesMemref(aliasAnalysis, memrefs.drop_front(i + 1),
+                                   memref, /*allowSelfAlias=*/false);
+      if (mayAliasOther) {
+        remainingMemrefs.push_back(memref);
+        remainingConditions.push_back(conditions[i]);
+        continue;
+      }
+      splitDeallocs.push_back(rewriter.create<DeallocOp>(
+          loc, memref, conditions[i], deallocOp.getRetained()));
+    }
+
+    // No IR has been created if nothing could be split off.
+    if (splitDeallocs.empty())
+      return failure();
+
+    // Gather the updated conditions that have to be OR-ed together. If memrefs
+    // remain in the original dealloc, its (shrunk) results contribute as well;
+    // the combining ops then consume those results and must therefore be
+    // placed after `deallocOp` to respect dominance. A dealloc without memrefs
+    // yields `false` for every retained value, so it need not be combined.
+    SmallVector<ValueRange> partialResults;
+    if (!remainingMemrefs.empty()) {
+      partialResults.push_back(deallocOp.getUpdatedConditions());
+      rewriter.setInsertionPointAfter(deallocOp);
+    }
+    for (DeallocOp splitDealloc : splitDeallocs)
+      partialResults.push_back(splitDealloc.getUpdatedConditions());
+
+    // `partialResults` has at least two entries: at least two memrefs were
+    // distributed among the original and the split-off deallocs.
+    SmallVector<Value> replacements(partialResults.front().begin(),
+                                    partialResults.front().end());
+    DenseSet<Operation *> combiningOps;
+    for (ValueRange partial : llvm::drop_begin(partialResults)) {
+      for (size_t j = 0, e = replacements.size(); j < e; ++j) {
+        auto orOp =
+            rewriter.create<arith::OrIOp>(loc, replacements[j], partial[j]);
+        combiningOps.insert(orOp);
+        replacements[j] = orOp.getResult();
+      }
+    }
+
+    // Redirect all uses of the original results to the combined values, except
+    // for the combining ops themselves (which may consume those results).
+    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
+                               [&](OpOperand &operand) {
+                                 return !combiningOps.contains(
+                                     operand.getOwner());
+                               });
+
+    // Shrink the original dealloc to the memrefs that could not be split off.
+    return updateDeallocIfChanged(deallocOp, remainingMemrefs,
+                                  remainingConditions, rewriter);
+  }
+
+private:
+  AliasAnalysis &aliasAnalysis;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -244,8 +333,9 @@ struct BufferDeallocationSimplificationPass
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
RewritePatternSet patterns(&getContext());
patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
- RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
- aliasAnalysis);
+ RemoveRetainedMemrefsGuaranteedToNotAlias,
+ SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
+ aliasAnalysis);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 79f1972392ffcb..b2656cdd2aa18b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -52,3 +52,29 @@ func.func @dealloc_deallocated_in_retained(%arg0: i1, %arg1: memref<2xi32>) -> (
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
// CHECK-NOT: retain
// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
+// %alloc and %arg2 cannot alias each other, so each of them is split off into
+// its own bufferization.dealloc and the results are OR-ed together.
+func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
+ %alloc = memref.alloc() : memref<2xi32>
+ %alloc0 = memref.alloc() : memref<2xi32>
+ %0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
+ %1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
+ return %1#0, %1#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
+// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+// CHECK-NEXT: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK-NEXT: [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
+// COM: there is only one value in the retained list because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
+// COM: removes %arg1 from the list. In the second dealloc, this does not apply
+// COM: because function arguments are assumed to potentially alias (even if
+// COM: their types don't exactly match).
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
+// CHECK-NEXT: bufferization.dealloc
+// CHECK-NEXT: return [[V2]]#0, [[V3]] :
More information about the Mlir-commits
mailing list