[Mlir-commits] [mlir] 765e82e - [mlir][bufferization] Generalize dealloc simplification pattern
Martin Erhart
llvmlistbot at llvm.org
Mon Aug 21 06:08:57 PDT 2023
Author: Martin Erhart
Date: 2023-08-21T13:08:00Z
New Revision: 765e82eeb6ca24b5bc6554b89c776fd250e72fdd
URL: https://github.com/llvm/llvm-project/commit/765e82eeb6ca24b5bc6554b89c776fd250e72fdd
DIFF: https://github.com/llvm/llvm-project/commit/765e82eeb6ca24b5bc6554b89c776fd250e72fdd.diff
LOG: [mlir][bufferization] Generalize dealloc simplification pattern
We are allowed to remove any value from the `memref` list for which no memref in
the `retained` list has a may-alias relation and at least one has a must-alias
relation. Before removing it, we just have to make sure that the op results
corresponding to all must-aliasing retained memrefs are updated accordingly,
i.e., the removed value's condition operand has to become part of the
disjunction with which each of those results is computed.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158395
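The dealloc semantics the commit message relies on can be checked with a small C++ model. It is only a sketch under loudly stated assumptions: every memref is reduced to a hypothetical `BufferId` of the allocation it points into, runtime aliasing is modeled as id equality (so every static answer is either no-alias or must-alias), and `deallocResults` / `deallocResultsAfterRemoving` are illustrative helpers, not MLIR APIs. The model computes one result per retained memref as the disjunction of the conditions of its aliases, and lets us confirm that dropping a memref that some retained memref must-aliases, while OR-ing its condition into those results, leaves every result unchanged.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model (assumption, not MLIR): a memref is identified by the allocation
// it points into; two memrefs alias at runtime exactly when the ids match.
using BufferId = int;

// One i1 per retained memref: the disjunction of the conditions of all
// entries in the `memref` list that alias it.
std::vector<bool> deallocResults(const std::vector<BufferId> &memrefs,
                                 const std::vector<bool> &conds,
                                 const std::vector<BufferId> &retained) {
  assert(memrefs.size() == conds.size() && "one condition per memref");
  std::vector<bool> results(retained.size(), false);
  for (std::size_t i = 0; i < retained.size(); ++i)
    for (std::size_t j = 0; j < memrefs.size(); ++j)
      if (memrefs[j] == retained[i] && conds[j])
        results[i] = true;
  return results;
}

// Drop memrefs[j] from the operand lists and OR its condition into the result
// of every retained memref that must-aliases it (id equality in this model).
std::vector<bool>
deallocResultsAfterRemoving(std::vector<BufferId> memrefs,
                            std::vector<bool> conds,
                            const std::vector<BufferId> &retained,
                            std::size_t j) {
  assert(j < memrefs.size() && memrefs.size() == conds.size());
  const BufferId removed = memrefs[j];
  const bool removedCond = conds[j];
  // The pattern's precondition: at least one retained memref must-aliases it.
  assert(std::find(retained.begin(), retained.end(), removed) !=
             retained.end() &&
         "memref must be must-aliased by some retained memref");
  const auto offset = static_cast<std::ptrdiff_t>(j);
  memrefs.erase(memrefs.begin() + offset);
  conds.erase(conds.begin() + offset);
  std::vector<bool> results = deallocResults(memrefs, conds, retained);
  for (std::size_t i = 0; i < retained.size(); ++i)
    if (retained[i] == removed)
      results[i] = results[i] || removedCond;
  return results;
}
```

For example, with memref ids `{0, 0}`, conditions `{c1, c3}` and retained ids `{0, 1, 0}` (roughly the shape of `%5` in the updated test), both helpers return `{c1 || c3, false, c1 || c3}` no matter which memref is removed. Since a removed memref aliases a retained one, it is never actually freed, so dropping it does not change which buffers get deallocated either.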
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 86b615575544b0..41b25bbd2e7d26 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -73,19 +73,19 @@ static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
namespace {
/// Remove values from the `memref` operand list that are also present in the
-/// `retained` list since they will always alias and thus never actually be
-/// deallocated. However, we also need to be certain that no other value in the
-/// `retained` list can alias, for which we use a static alias analysis. This is
-/// necessary because the `dealloc` operation is defined to return one `i1`
-/// value per memref in the `retained` list which represents the disjunction of
-/// the condition values corresponding to all aliasing values in the `memref`
-/// list. In particular, this means that if there is some value R in the
-/// `retained` list which aliases with a value M in the `memref` list (but can
-/// only be staticaly determined to may-alias) and M is also present in the
-/// `retained` list, then it would be illegal to remove M because the result
-/// corresponding to R would be computed incorrectly afterwards.
-/// Because we require an alias analysis, this pattern cannot be applied as a
-/// regular canonicalization pattern.
+/// `retained` list (or a guaranteed alias of it) because they will never
+/// actually be deallocated. However, we also need to be certain about which
+/// other memrefs in the `retained` list can alias, i.e., there must not be any
+/// may-aliasing memref. This is necessary because the `dealloc` operation is
+/// defined to return one `i1` value per memref in the `retained` list which
+/// represents the disjunction of the condition values corresponding to all
+/// aliasing values in the `memref` list. In particular, this means that if
+/// there is some value R in the `retained` list which aliases with a value M in
+/// the `memref` list (but can only be statically determined to may-alias) and M
+/// is also present in the `retained` list, then it would be illegal to remove M
+/// because the result corresponding to R would be computed incorrectly
+/// afterwards. Because we require an alias analysis, this pattern cannot be
+/// applied as a regular canonicalization pattern.
///
/// Example:
/// ```mlir
@@ -101,63 +101,75 @@ namespace {
/// // replace %0#0 with %1
/// ```
/// given that `%r0` and `%r1` may not alias with `%m0`.
-struct DeallocRemoveDeallocMemrefsContainedInRetained
+struct RemoveDeallocMemrefsContainedInRetained
: public OpRewritePattern<DeallocOp> {
- DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
+ RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
+ AliasAnalysis &aliasAnalysis)
: OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ /// The passed 'memref' must not have a may-alias relation to any retained
+ /// memref, and at least one must-alias relation. If there is no must-aliasing
+ /// memref in the retain list, we cannot simply remove the memref as there
+ /// could be situations in which it actually has to be deallocated. No-alias
+ /// retained memrefs need no handling; for each must-aliasing one, the updated
+ /// condition returned by the dealloc operation must incorporate `cond`.
+ LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
+ PatternRewriter &rewriter) const {
+ rewriter.setInsertionPointAfter(deallocOp);
+
+ // Check that there is no may-aliasing memref and that at least one memref
+ // in the retain list aliases (because otherwise it might have to be
+ // deallocated in some situations and can thus not be dropped).
+ bool atLeastOneMustAlias = false;
+ for (Value retained : deallocOp.getRetained()) {
+ AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
+ if (analysisResult.isMay())
+ return failure();
+ if (analysisResult.isMust() || analysisResult.isPartial())
+ atLeastOneMustAlias = true;
+ }
+ if (!atLeastOneMustAlias)
+ return failure();
+
+ // Insert arith.ori operations to update the corresponding dealloc result
+ // values to incorporate the condition of the must-aliasing memref such that
+ // we can remove that operand later on.
+ for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
+ Value updatedCondition = deallocOp.getUpdatedConditions()[i];
+ AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
+ if (analysisResult.isMust() || analysisResult.isPartial()) {
+ auto disjunction = rewriter.create<arith::OrIOp>(
+ deallocOp.getLoc(), updatedCondition, cond);
+ rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
+ disjunction);
+ }
+ }
+
+ return success();
+ }
+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
- // Unique memrefs to be deallocated.
- DenseMap<Value, unsigned> retained;
- for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
- retained[ret] = i;
-
// There must not be any duplicates in the retain list anymore because we
// would miss updating one of the result values otherwise.
+ DenseSet<Value> retained(deallocOp.getRetained().begin(),
+ deallocOp.getRetained().end());
if (retained.size() != deallocOp.getRetained().size())
return failure();
SmallVector<Value> newMemrefs, newConditions;
- for (auto memrefAndCond :
+ for (auto [memref, cond] :
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
- Value memref = std::get<0>(memrefAndCond);
- Value cond = std::get<1>(memrefAndCond);
-
- auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool {
- Value retainedMemref = deallocOp.getRetained()[retained[memref]];
- // The current memref must not have a may-alias relation to any retained
- // memref, and exactly one must-alias relation.
- // TODO: it is possible to extend this pattern to allow an arbitrary
- // number of must-alias relations as long as there is no may-alias. If
- // it's no-alias, then just proceed (only supported case as of now), if
- // it's must-alias, we also need to update the condition for that alias.
- if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) {
- return aliasAnalysis.alias(mr, memref).isNo() ||
- mr == retainedMemref;
- })) {
- rewriter.setInsertionPointAfter(deallocOp);
- auto orOp = rewriter.create<arith::OrIOp>(
- deallocOp.getLoc(),
- deallocOp.getUpdatedConditions()[retained[memref]], cond);
- rewriter.replaceAllUsesExcept(
- deallocOp.getUpdatedConditions()[retained[memref]],
- orOp.getResult(), orOp);
- return true;
- }
- return false;
- };
-
- if (retained.contains(memref) &&
- replaceResultsIfNoInvalidAliasing(memref))
- continue;
- auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
- if (extractOp && retained.contains(extractOp.getOperand()) &&
- replaceResultsIfNoInvalidAliasing(extractOp.getOperand()))
+ if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
continue;
+ if (auto extractOp =
+ memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
+ if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
+ rewriter)))
+ continue;
+
newMemrefs.push_back(memref);
newConditions.push_back(cond);
}
@@ -332,7 +344,7 @@ struct BufferDeallocationSimplificationPass
void runOnOperation() override {
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
RewritePatternSet patterns(&getContext());
- patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
+ patterns.add<RemoveDeallocMemrefsContainedInRetained,
RemoveRetainedMemrefsGuaranteedToNotAlias,
SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
aliasAnalysis);
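The decision made by `handleOneMemref`, together with the `extract_strided_metadata` fallback in `matchAndRewrite`, can be restated as a pure function. This is a hedged sketch: `Alias` is a hypothetical enum standing in for the four outcomes `mlir::AliasResult` exposes through `isNo`, `isMay`, `isPartial` and `isMust`, and the helper names are invented. The functions only compute which dealloc results would receive an `arith.ori`; they do not touch any IR.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for the outcomes of an alias query.
enum class Alias { No, May, Partial, Must };

// `vsRetained[i]` is the alias result of one `memref` operand against the i-th
// retained memref. Returns the indices of the dealloc results that must absorb
// the operand's condition, or std::nullopt if the operand has to be kept.
std::optional<std::vector<std::size_t>>
resultsToUpdate(const std::vector<Alias> &vsRetained) {
  std::vector<std::size_t> indices;
  for (std::size_t i = 0; i < vsRetained.size(); ++i) {
    switch (vsRetained[i]) {
    case Alias::May:
      return std::nullopt; // statically unknown which result it contributes to
    case Alias::Partial:   // treated like must-alias, as in the patch
    case Alias::Must:
      indices.push_back(i);
      break;
    case Alias::No:
      break;
    }
  }
  if (indices.empty())
    return std::nullopt; // nothing retains it, so it may really need freeing
  return indices;
}

// Try the operand itself first; if it was produced by
// memref.extract_strided_metadata, retry with that op's source memref.
std::optional<std::vector<std::size_t>>
resultsToUpdateWithFallback(const std::vector<Alias> &direct,
                            const std::optional<std::vector<Alias>> &viaSource) {
  if (auto result = resultsToUpdate(direct))
    return result;
  if (viaSource)
    return resultsToUpdate(*viaSource);
  return std::nullopt;
}
```

The patch scans the retain list twice so that no `arith.ori` is created before it knows the memref can be dropped; the sketch gets away with a single pass because it has no side effects, so bailing out midway is harmless. The uniqueness check on the retain list in `matchAndRewrite` is not modeled here.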
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index b2656cdd2aa18b..6c63166015e843 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -1,51 +1,79 @@
// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s
-func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
%0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
%1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
%2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
- return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+ // multiple must-alias
+ %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
+ %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+ %alloc = memref.alloc() : memref<2xi32>
+ %5:3 = bufferization.dealloc (%arg0, %4 : memref<2xi32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
+ return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
}
// CHECK-LABEL: func @dealloc_deallocated_in_retained
-// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+// CHECK-NEXT: arith.constant false
// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
+// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
+// COM: due to the pattern under test (and thus there is no memref the retain values
+// COM: could alias to)
+// CHECK-NEXT: bufferization.dealloc
+// CHECK-NOT: if
+// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :
// -----
-func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
%base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
%0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
%1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
%2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
- return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+ // multiple must-alias
+ %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
+ %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+ %alloc = memref.alloc() : memref<2xi32>
+ %5:3 = bufferization.dealloc (%base_buffer, %4 : memref<i32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
+ return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
}
// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
-// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+// CHECK-NEXT: arith.constant false
// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
// CHECK-NEXT: bufferization.dealloc
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
+// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
+// COM: due to the pattern under test (and thus there is no memref the retain values
+// COM: could alias to)
+// CHECK-NEXT: bufferization.dealloc
+// CHECK-NOT: if
+// CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+// CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+// CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :
// -----
-func.func @dealloc_deallocated_in_retained(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
%alloc = memref.alloc() : memref<2xi32>
%alloc0 = memref.alloc() : memref<2xi32>
%0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
return %0#0, %0#1 : i1, i1
}
-// CHECK-LABEL: func @dealloc_deallocated_in_retained
+// CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias
// CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
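The operand order in the new CHECK lines (`arith.ori [[ARG3]], [[ARG1]]`) follows from how the pattern chains its updates: every `arith.ori` is created right after the dealloc, i.e. in front of the ones created earlier, and `replaceAllUsesExcept` also redirects the operand of an earlier `arith.ori`. The C++ sketch below replays this on `%5` in a hypothetical string-based toy IR (`ToyIR`, the value names `r0`..`r2`, `c1`, `c3` and `orN` are invented for illustration; this is not MLIR's API). It then replaces the dealloc results with `false`, as the `%false` in the CHECK lines indicates happens once the memref list is empty, and folds `ori(false, x)` to `x`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy IR (hypothetical): an `ori` op named `name` with two operand names.
struct OrOp {
  std::string name, lhs, rhs;
};

struct ToyIR {
  std::vector<OrOp> ops;             // ops placed after the dealloc, in order
  std::vector<std::string> returned; // operands of the function's return

  // Like setInsertionPointAfter(dealloc) + create: the new op lands directly
  // after the dealloc, in front of every previously created op.
  std::string createOrAfterDealloc(const std::string &lhs,
                                   const std::string &rhs) {
    std::string name = "or" + std::to_string(ops.size());
    ops.insert(ops.begin(), OrOp{name, lhs, rhs});
    return name;
  }

  // Like replaceAllUsesExcept(from, to, exceptedUser).
  void replaceAllUsesExcept(const std::string &from, const std::string &to,
                            const std::string &exceptedUser) {
    for (OrOp &op : ops) {
      if (op.name == exceptedUser)
        continue;
      if (op.lhs == from)
        op.lhs = to;
      if (op.rhs == from)
        op.rhs = to;
    }
    for (std::string &value : returned)
      if (value == from)
        value = to;
  }

  // Render a value as a nested expression, folding away `false` operands.
  std::string render(const std::string &value) const {
    for (const OrOp &op : ops) {
      if (op.name != value)
        continue;
      std::string lhs = render(op.lhs), rhs = render(op.rhs);
      if (lhs == "false")
        return rhs;
      if (rhs == "false")
        return lhs;
      return "ori(" + lhs + ", " + rhs + ")";
    }
    return value; // a leaf: block argument or dealloc result
  }
};

// Replays the pattern on %5: results r0 (%arg0) and r2 (%3) must-alias both
// memref operands, whose conditions are c1 (%arg1) and c3 (%arg3).
ToyIR replayMultipleMustAlias() {
  ToyIR ir;
  ir.returned = {"r0", "r1", "r2"};
  for (const char *cond : {"c1", "c3"})
    for (const char *result : {"r0", "r2"}) {
      std::string disjunction = ir.createOrAfterDealloc(result, cond);
      ir.replaceAllUsesExcept(result, disjunction, disjunction);
    }
  return ir;
}
```

Before folding, the result for `%5#0` renders as `ori(ori(r0, c3), c1)`; after `r0` is replaced by `false` it becomes `ori(c3, c1)`, consistent with `[[V3]]` in the CHECK lines. Creating every op directly after the dealloc also keeps each disjunction ahead of its users: `or3`, created last, ends up first.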