[Mlir-commits] [mlir] 0bcae5e - [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass

Martin Erhart llvmlistbot at llvm.org
Tue Aug 15 05:40:35 PDT 2023


Author: Martin Erhart
Date: 2023-08-15T12:39:57Z
New Revision: 0bcae5e763c7bf951bad06f5913a9905f065539d

URL: https://github.com/llvm/llvm-project/commit/0bcae5e763c7bf951bad06f5913a9905f065539d
DIFF: https://github.com/llvm/llvm-project/commit/0bcae5e763c7bf951bad06f5913a9905f065539d.diff

LOG: [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass

Add a pattern that splits one dealloc operation into multiple dealloc operations
depending on static aliasing information of the values in the `memref` operand
list. This reduces the total number of aliasing checks required at runtime and
can enable further canonicalizations of the new and simplified dealloc
operations.

Depends on D157407

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157508

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index bd06b5111b82fd..86b615575544b0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -48,6 +48,24 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
+/// Returns true if 'memref' may or must alias any MemRef in 'memrefList'. Many
+/// optimization patterns require that no aliasing memref exists before they
+/// can perform their simplification. If 'allowSelfAlias' is true, entries of
+/// 'memrefList' that are the same value as 'memref' are skipped, which suits
+/// callers that already know 'memref' is in the list; if false, every entry,
+/// including 'memref' itself, goes through the alias query.
+static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+                                     ValueRange memrefList, Value memref,
+                                     bool allowSelfAlias) {
+  for (auto mr : memrefList) {
+    if (allowSelfAlias && mr == memref)
+      continue;
+    if (!analysis.alias(mr, memref).isNo())
+      return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -176,15 +194,6 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
                                             AliasAnalysis &aliasAnalysis)
       : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
 
-  bool potentiallyAliasesMemref(DeallocOp deallocOp,
-                                Value retainedMemref) const {
-    for (auto memref : deallocOp.getMemrefs()) {
-      if (!aliasAnalysis.alias(memref, retainedMemref).isNo())
-        return true;
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
@@ -197,7 +206,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
     };
 
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(deallocOp, retainedMemref)) {
+      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+                                   retainedMemref, false)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
         continue;
@@ -226,6 +236,85 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
   AliasAnalysis &aliasAnalysis;
 };
 
+/// Split off memrefs to separate dealloc operations to reduce the number of
+/// runtime checks required and enable further canonicalization of the new and
+/// simpler dealloc operations. A memref can be split off if it is guaranteed to
+/// not alias with any other memref in the `memref` operand list. The results
+/// of the old and the new dealloc operations have to be combined by computing
+/// their element-wise disjunction.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
+///                           if (%cond0, %cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// return %0#0, %0#1
+/// ```
+/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
+/// canonicalized to the following, thus reducing the number of runtime alias
+/// checks by 1 and potentially enabling further canonicalization of the new
+/// split-up dealloc operations.
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
+///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
+/// %2 = arith.ori %0#0, %1#0
+/// %3 = arith.ori %0#1, %1#1
+/// return %2, %3
+/// ```
+struct SplitDeallocWhenNotAliasingAnyOther
+    : public OpRewritePattern<DeallocOp> {
+  SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
+                                      AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    if (deallocOp.getMemrefs().size() <= 1)
+      return failure();
+
+    SmallVector<Value> newMemrefs, newConditions, replacements;
+    DenseSet<Operation *> exceptedUsers;
+    replacements = deallocOp.getUpdatedConditions();
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+                                   memref, true)) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+        continue;
+      }
+
+      auto newDeallocOp = rewriter.create<DeallocOp>(
+          deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
+      replacements = SmallVector<Value>(llvm::map_range(
+          llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
+          [&](auto replAndNew) -> Value {
+            auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
+                                                      std::get<0>(replAndNew),
+                                                      std::get<1>(replAndNew));
+            exceptedUsers.insert(orOp);
+            return orOp.getResult();
+          }));
+    }
+
+    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+      return failure();
+
+    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
+                               [&](OpOperand &operand) {
+                                 return !exceptedUsers.contains(
+                                     operand.getOwner());
+                               });
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+
+private:
+  AliasAnalysis &aliasAnalysis;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -244,8 +333,9 @@ struct BufferDeallocationSimplificationPass
     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
     RewritePatternSet patterns(&getContext());
     patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
-                 RemoveRetainedMemrefsGuaranteedToNotAlias>(&getContext(),
-                                                            aliasAnalysis);
+                 RemoveRetainedMemrefsGuaranteedToNotAlias,
+                 SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
+                                                      aliasAnalysis);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 79f1972392ffcb..b2656cdd2aa18b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -52,3 +52,29 @@ func.func @dealloc_deallocated_in_retained(%arg0: i1, %arg1: memref<2xi32>) -> (
 //  CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
 //  CHECK-NOT: retain
 //  CHECK-NEXT: return [[FALSE]], [[FALSE]] :
+
+// -----
+
+func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
+  %alloc = memref.alloc() : memref<2xi32>
+  %alloc0 = memref.alloc() : memref<2xi32>
+  %0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
+  %1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
+  return %1#0, %1#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
+//  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+//  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
+//  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
+//  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
+// COM: The first dealloc has only one value in its retained list because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
+// COM: removes %arg1 from the list. In the second dealloc, this does not apply
+// COM: because function arguments are assumed to potentially alias (even if
+// COM: the types don't exactly match).
+//  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
+//  CHECK-NEXT:   [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
+//  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
+//  CHECK-NEXT:   bufferization.dealloc
+//  CHECK-NEXT:   return [[V2]]#0, [[V3]] :


        


More information about the Mlir-commits mailing list