[Mlir-commits] [mlir] fea185d - [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass

Martin Erhart llvmlistbot at llvm.org
Wed Aug 23 03:41:49 PDT 2023


Author: Martin Erhart
Date: 2023-08-23T10:41:05Z
New Revision: fea185d70d7e8699dd5a117743ad7b993f4ea21e

URL: https://github.com/llvm/llvm-project/commit/fea185d70d7e8699dd5a117743ad7b993f4ea21e
DIFF: https://github.com/llvm/llvm-project/commit/fea185d70d7e8699dd5a117743ad7b993f4ea21e.diff

LOG: [mlir][bufferization] Add pattern to BufferDeallocationSimplification pass

This new pattern allows us to simplify the dealloc result value (by replacing
it with a constant 'true') and to trim the 'memref' operand list when we know
that all retained memrefs alias with one in the 'memref' list that has a
constant 'true' condition. Because the conditions of aliasing memrefs are
combined by disjunction, we know that once a single constant 'true' value is in
the disjunction the remaining elements don't matter anymore. This complements
the RemoveDeallocMemrefsContainedInRetained pattern, which removes values from
the 'memref' list only when static information is available for all retained
values: this pattern also allows removing values in the presence of
may-aliases, but only under the above-mentioned condition.
The BufferDeallocation pass often adds dealloc operations where the memref and
retain lists are the same and all conditions are 'true'. If the operands are
all function arguments, for example, they are always determined to may-alias,
which makes the other patterns inapplicable even though the op could trivially
be optimized away. It would even be enough to directly compare the two operand
lists and check that the conditions are all constant 'true' (plus checking for the
extract_strided_metadata operation), but this pattern is a bit more general and
still works when there are additional memrefs in the 'memref' list that actually
have to be deallocated (e.g., see regression test).

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158518

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 41b25bbd2e7d26..1090cd668ca78c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -327,6 +328,88 @@ struct SplitDeallocWhenNotAliasingAnyOther
   AliasAnalysis &aliasAnalysis;
 };
 
+/// Check for every retained memref if a memref that must- or partially-aliases
+/// it exists in the 'memref' operand list with a constant 'true' condition. If
+/// so, the operation result corresponding to that retained memref can be
+/// replaced with 'true'. Only if this holds for all retained memrefs are the
+/// results replaced and the aliasing memrefs (and their conditions) removed:
+/// they will never be deallocated due to the alias, and they are not needed to
+/// compute the result values anymore since those were replaced with 'true'.
+///
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
+///                           if (%true, %true, %true)
+///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+/// ```
+/// becomes
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
+///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+/// // replace %0#0 with %true
+/// // replace %0#1 with %true
+/// ```
+/// Note that the dealloc operation will still have the result values, but they
+/// don't have uses anymore.
+struct RetainedMemrefAliasingAlwaysDeallocatedMemref
+    : public OpRewritePattern<DeallocOp> {
+  RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
+                                                AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // replacements[i] is the constant 'true' condition that result #i is
+    // replaced with, or null if no aliasing memref was found (yet). The IR is
+    // only modified after all of them are known: a pattern must not change
+    // the IR and then return failure.
+    SmallVector<Value> replacements(deallocOp.getRetained().size());
+    SmallVector<Value> newMemrefs, newConditions;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      bool canDropMemref = false;
+      // Only memrefs that are unconditionally deallocated can be dropped.
+      if (matchPattern(cond, m_One())) {
+        // TODO: once our alias analysis is powerful enough we can remove the
+        // special handling of extract_strided_metadata.
+        auto extractOp =
+            memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
+        for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
+          if (mustOrPartiallyAlias(retained, memref) ||
+              (extractOp &&
+               mustOrPartiallyAlias(retained, extractOp.getOperand()))) {
+            replacements[i] = cond;
+            canDropMemref = true;
+          }
+        }
+      }
+
+      if (!canDropMemref) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+      }
+    }
+    if (llvm::is_contained(replacements, Value()))
+      return failure();
+
+    // Every match above dropped a memref, so the operand lists did change.
+    for (auto [res, replacement] :
+         llvm::zip(deallocOp.getUpdatedConditions(), replacements))
+      rewriter.replaceAllUsesWith(res, replacement);
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+
+private:
+  /// Returns true if `lhs` and `rhs` must-alias or partially alias.
+  bool mustOrPartiallyAlias(Value lhs, Value rhs) const {
+    AliasResult result = aliasAnalysis.alias(lhs, rhs);
+    return result.isMust() || result.isPartial();
+  }
+
+  AliasAnalysis &aliasAnalysis;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -346,8 +429,9 @@ struct BufferDeallocationSimplificationPass
     RewritePatternSet patterns(&getContext());
     patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
-                 SplitDeallocWhenNotAliasingAnyOther>(&getContext(),
-                                                      aliasAnalysis);
+                 SplitDeallocWhenNotAliasingAnyOther,
+                 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
+                                                                aliasAnalysis);
 
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 6c63166015e843..473ce8734f0cd9 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -106,3 +106,35 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
 //  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
 //  CHECK-NEXT:   bufferization.dealloc
 //  CHECK-NEXT:   return [[V2]]#0, [[V3]] :
+
+// -----
+
+func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition(
+  %m0: memref<2xi32>, %m1: memref<2xi32>, %m2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
+  %cst_true = arith.constant true
+  %updated:2 = bufferization.dealloc (%m0, %m1, %m2 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%cst_true, %cst_true, %cst_true) retain (%m0, %m1 : memref<2xi32>, memref<2xi32>)
+  return %m0, %m1, %updated#0, %updated#1 : memref<2xi32>, memref<2xi32>, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
+//       CHECK:   bufferization.dealloc ([[ARG2]] :{{.*}}) if (%true{{[0-9_]*}})
+//  CHECK-NEXT:   return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :
+
+// -----
+
+func.func @dealloc_remove_extracted_base_memref_contained_in_retained_with_const_true_condition(
+  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
+  %true = arith.constant true
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
+  %base_buffer_1, %offset_1, %size_1, %stride_1 = memref.extract_strided_metadata %arg1 : memref<2xi32> -> memref<i32>, index, index, index
+  %base_buffer_2, %offset_2, %size_2, %stride_2 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
+  %0:2 = bufferization.dealloc (%base_buffer, %base_buffer_1, %base_buffer_2 : memref<i32>, memref<i32>, memref<i32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_remove_extracted_base_memref_contained_in_retained_with_const_true_condition
+//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
+//       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
+//       CHECK:   bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
+//  CHECK-NEXT:   return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :


        


More information about the Mlir-commits mailing list