[Mlir-commits] [mlir] 765e82e - [mlir][bufferization] Generalize dealloc simplification pattern

Martin Erhart llvmlistbot at llvm.org
Mon Aug 21 06:08:57 PDT 2023


Author: Martin Erhart
Date: 2023-08-21T13:08:00Z
New Revision: 765e82eeb6ca24b5bc6554b89c776fd250e72fdd

URL: https://github.com/llvm/llvm-project/commit/765e82eeb6ca24b5bc6554b89c776fd250e72fdd
DIFF: https://github.com/llvm/llvm-project/commit/765e82eeb6ca24b5bc6554b89c776fd250e72fdd.diff

LOG: [mlir][bufferization] Generalize dealloc simplification pattern

We are allowed to remove any values from the `memref` list for which there is no
memref in the `retained` list with a may-alias relation. In addition, at least
one retained memref has to must-alias the value, since otherwise it might still
have to be deallocated. Before removing, we just have to make sure that the
corresponding op results for all retained memrefs with a must-alias relation
are updated accordingly. This means that the condition operand has to become
part of the disjunction the result value is computed with.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158395

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 86b615575544b0..41b25bbd2e7d26 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -73,19 +73,19 @@ static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
 namespace {
 
 /// Remove values from the `memref` operand list that are also present in the
-/// `retained` list since they will always alias and thus never actually be
-/// deallocated. However, we also need to be certain that no other value in the
-/// `retained` list can alias, for which we use a static alias analysis. This is
-/// necessary because the `dealloc` operation is defined to return one `i1`
-/// value per memref in the `retained` list which represents the disjunction of
-/// the condition values corresponding to all aliasing values in the `memref`
-/// list. In particular, this means that if there is some value R in the
-/// `retained` list which aliases with a value M in the `memref` list (but can
-/// only be staticaly determined to may-alias) and M is also present in the
-/// `retained` list, then it would be illegal to remove M because the result
-/// corresponding to R would be computed incorrectly afterwards.
-/// Because we require an alias analysis, this pattern cannot be applied as a
-/// regular canonicalization pattern.
+/// `retained` list (or a guaranteed alias of it) because they will never
+/// actually be deallocated. However, we also need to be certain about which
+/// other memrefs in the `retained` list can alias, i.e., there must not be any
+/// may-aliasing memref. This is necessary because the `dealloc` operation is
+/// defined to return one `i1` value per memref in the `retained` list which
+/// represents the disjunction of the condition values corresponding to all
+/// aliasing values in the `memref` list. In particular, this means that if
+/// there is some value R in the `retained` list which aliases with a value M in
+/// the `memref` list (but can only be statically determined to may-alias) and M
+/// is also present in the `retained` list, then it would be illegal to remove M
+/// because the result corresponding to R would be computed incorrectly
+/// afterwards.  Because we require an alias analysis, this pattern cannot be
+/// applied as a regular canonicalization pattern.
 ///
 /// Example:
 /// ```mlir
@@ -101,63 +101,75 @@ namespace {
 /// // replace %0#0 with %1
 /// ```
 /// given that `%r0` and `%r1` may not alias with `%m0`.
-struct DeallocRemoveDeallocMemrefsContainedInRetained
+struct RemoveDeallocMemrefsContainedInRetained
     : public OpRewritePattern<DeallocOp> {
-  DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
-                                                 AliasAnalysis &aliasAnalysis)
+  RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
+                                          AliasAnalysis &aliasAnalysis)
       : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
 
+  /// Try to drop `memref` (guarded by `cond`) from the dealloc's memref list.
+  /// This is only legal if no retained memref may-aliases `memref` and at
+  /// least one retained memref must-aliases (or partially aliases) it; without
+  /// such an alias, `memref` might still have to be deallocated. On success,
+  /// uses of each such alias's updated condition are replaced by an `arith.ori`
+  /// of that condition and `cond`. On failure, no IR has been modified.
+  LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
+                                PatternRewriter &rewriter) const {
+    rewriter.setInsertionPointAfter(deallocOp);
+
+    // Bail out on any may-alias. Partial aliases are treated like must-aliases,
+    // and at least one such alias is required; otherwise `memref` might have to
+    // be deallocated in some situations and can thus not be dropped.
+    bool atLeastOneMustAlias = false;
+    for (Value retained : deallocOp.getRetained()) {
+      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
+      if (analysisResult.isMay())
+        return failure();
+      if (analysisResult.isMust() || analysisResult.isPartial())
+        atLeastOneMustAlias = true;
+    }
+    if (!atLeastOneMustAlias)
+      return failure();
+
+    // Fold `cond` into the result of every retained memref that (partially)
+    // must-aliases `memref`, so the operand can be removed from the memref list
+    // afterwards. The new `arith.ori` is excluded from the use replacement.
+    for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
+      Value updatedCondition = deallocOp.getUpdatedConditions()[i];
+      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
+      if (analysisResult.isMust() || analysisResult.isPartial()) {
+        auto disjunction = rewriter.create<arith::OrIOp>(
+            deallocOp.getLoc(), updatedCondition, cond);
+        rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
+                                      disjunction);
+      }
+    }
+
+    return success();
+  }
+
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
-    // Unique memrefs to be deallocated.
-    DenseMap<Value, unsigned> retained;
-    for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
-      retained[ret] = i;
-
     // There must not be any duplicates in the retain list anymore because we
     // would miss updating one of the result values otherwise.
+    DenseSet<Value> retained(deallocOp.getRetained().begin(),
+                             deallocOp.getRetained().end());
     if (retained.size() != deallocOp.getRetained().size())
       return failure();
 
     SmallVector<Value> newMemrefs, newConditions;
-    for (auto memrefAndCond :
+    for (auto [memref, cond] :
          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
-      Value memref = std::get<0>(memrefAndCond);
-      Value cond = std::get<1>(memrefAndCond);
-
-      auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool {
-        Value retainedMemref = deallocOp.getRetained()[retained[memref]];
-        // The current memref must not have a may-alias relation to any retained
-        // memref, and exactly one must-alias relation.
-        // TODO: it is possible to extend this pattern to allow an arbitrary
-        // number of must-alias relations as long as there is no may-alias. If
-        // it's no-alias, then just proceed (only supported case as of now), if
-        // it's must-alias, we also need to update the condition for that alias.
-        if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) {
-              return aliasAnalysis.alias(mr, memref).isNo() ||
-                     mr == retainedMemref;
-            })) {
-          rewriter.setInsertionPointAfter(deallocOp);
-          auto orOp = rewriter.create<arith::OrIOp>(
-              deallocOp.getLoc(),
-              deallocOp.getUpdatedConditions()[retained[memref]], cond);
-          rewriter.replaceAllUsesExcept(
-              deallocOp.getUpdatedConditions()[retained[memref]],
-              orOp.getResult(), orOp);
-          return true;
-        }
-        return false;
-      };
-
-      if (retained.contains(memref) &&
-          replaceResultsIfNoInvalidAliasing(memref))
-        continue;
 
-      auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
-      if (extractOp && retained.contains(extractOp.getOperand()) &&
-          replaceResultsIfNoInvalidAliasing(extractOp.getOperand()))
+      if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
         continue;
 
+      if (auto extractOp =
+              memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
+        if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
+                                      rewriter)))
+          continue;
+
       newMemrefs.push_back(memref);
       newConditions.push_back(cond);
     }
@@ -332,7 +344,7 @@ struct BufferDeallocationSimplificationPass
   void runOnOperation() override {
     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
     RewritePatternSet patterns(&getContext());
-    patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained,
+    patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
                  SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
                                                       aliasAnalysis);

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index b2656cdd2aa18b..6c63166015e843 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -1,51 +1,79 @@
 // RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s
 
-func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
   %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
   %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
   %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
-  return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+  // multiple must-alias
+  %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
+  %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+  %alloc = memref.alloc() : memref<2xi32>
+  %5:3 = bufferization.dealloc (%arg0, %4 : memref<2xi32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
+  return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @dealloc_deallocated_in_retained
-//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+//  CHECK-NEXT: arith.constant false
 //  CHECK-NEXT: bufferization.dealloc
 //  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 //  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 //  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
+// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
+// COM: due to the pattern under test (and thus there is no memref the retain values
+// COM: could alias to)
+//  CHECK-NEXT: bufferization.dealloc
+// CHECK-NOT: if
+//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :
 
 // -----
 
-func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
   %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
   %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
   %0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
   %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
   %2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
-  return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+  // multiple must-alias
+  %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
+  %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+  %alloc = memref.alloc() : memref<2xi32>
+  %5:3 = bufferization.dealloc (%base_buffer, %4 : memref<i32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
+  return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
-//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
+//  CHECK-NEXT: arith.constant false
 //  CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
 //  CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
 //  CHECK-NEXT: bufferization.dealloc
 //  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 //  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
 //  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
-//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1 :
+// COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
+// COM: retained memrefs since the list of memrefs to be deallocated becomes empty
+// COM: due to the pattern under test (and thus there is no memref the retain values
+// COM: could alias to)
+//  CHECK-NEXT: bufferization.dealloc
+// CHECK-NOT: if
+//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
+//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :
 
 // -----
 
-func.func @dealloc_deallocated_in_retained(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
+func.func @remove_retained_memrefs_guarateed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1) {
   %alloc = memref.alloc() : memref<2xi32>
   %alloc0 = memref.alloc() : memref<2xi32>
   %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
   return %0#0, %0#1 : i1, i1
 }
 
-// CHECK-LABEL: func @dealloc_deallocated_in_retained
+// CHECK-LABEL: func @remove_retained_memrefs_guarateed_to_not_alias
 //  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
 //  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
 //  CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(


        


More information about the Mlir-commits mailing list