[Mlir-commits] [mlir] [mlir][bufferization] Add `BufferOriginAnalysis` (PR #86461)

Matthias Springer llvmlistbot at llvm.org
Sun Mar 24 19:58:36 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86461

>From be4524ef37903229a57159a6d7b668c83551b7a3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 25 Mar 2024 02:57:40 +0000
Subject: [PATCH] [mlir][bufferization] BufferOriginAnalysis

---
 .../Bufferization/IR/BufferizationOps.td      |   1 +
 .../Transforms/BufferViewFlowAnalysis.h       |  36 ++++
 .../BufferDeallocationSimplification.cpp      |  80 ++++-----
 .../Transforms/BufferViewFlowAnalysis.cpp     | 160 ++++++++++++++++--
 .../dealloc-loops.mlir                        |  86 ++++++++++
 .../buffer-deallocation-simplification.mlir   |  14 +-
 6 files changed, 321 insertions(+), 56 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c8..4f609ddff9a413 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
 #define BUFFERIZATION_OPS
 
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 9e43265c5dfede..4015231c845daf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
   ///
   /// Results in resolve(B) returning {B, C}
   ValueSetT resolve(Value value) const;
+  ValueSetT resolveReverse(Value value) const;
 
   /// Removes the given values from all alias sets.
   void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+  ValueMapT reverseDependencies;
 
   /// A set of all SSA values that may be terminal buffers.
   DenseSet<Value> terminals;
 };
 
+/// An is-same-buffer analysis that checks if two SSA values belong to the same
+/// buffer allocation or not, using a `BufferViewFlowAnalysis` of the given op.
+class BufferOriginAnalysis {
+public:
+  explicit BufferOriginAnalysis(Operation *op);
+
+  /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+  /// Return "false" if `v1` and `v2` originate from different allocations.
+  /// Return "nullopt" if we do not know for sure. Both must be memref-typed.
+  ///
+  /// Example 1: isSameAllocation(%0, %1) == true
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.subview %0
+  /// ```
+  ///
+  /// Example 2: isSameAllocation(%0, %1) == false
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// ```
+  ///
+  /// Example 3: isSameAllocation(%0, %2) == nullopt
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// %2 = arith.select %c, %0, %1
+  /// ```
+  std::optional<bool> isSameAllocation(Value v1, Value v2);
+
+private:
+  BufferViewFlowAnalysis analysis;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b4753..954485cfede3da 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
 // Helpers
 //===----------------------------------------------------------------------===//
 
+/// Peel off all ViewLikeOpInterface ops from `value`; return the base buffer.
+static Value getViewBase(Value value) {
+  Value base = value;
+  while (auto view = base.getDefiningOp<ViewLikeOpInterface>())
+    base = view.getViewSource();
+  return base;
+}
+
 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                             ValueRange memrefs,
                                             ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
-  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
-    value = viewLikeOp.getViewSource();
-  return value;
-}
-
 /// Return "true" if the given values are guaranteed to be different (and
 /// non-aliasing) allocations based on the fact that one value is the result
 /// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
 /// often a requirement of optimization patterns that there cannot be any
 /// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
                                      ValueRange otherList, Value memref) {
   for (auto other : otherList) {
     if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(other, memref).isNo())
+    std::optional<bool> analysisResult =
+        analysis.isSameAllocation(other, memref);
+    if (!analysisResult.has_value() || analysisResult == true)
       return true;
   }
   return false;
@@ -129,8 +131,8 @@ namespace {
 struct RemoveDeallocMemrefsContainedInRetained
     : public OpRewritePattern<DeallocOp> {
   RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
-                                          AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                          BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   /// The passed 'memref' must not have a may-alias relation to any retained
   /// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
     // deallocated in some situations and can thus not be dropped).
     bool atLeastOneMustAlias = false;
     for (Value retained : deallocOp.getRetained()) {
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMay())
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (!analysisResult.has_value())
         return failure();
-      if (analysisResult.isMust() || analysisResult.isPartial())
+      if (analysisResult == true)
         atLeastOneMustAlias = true;
     }
     if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
     // we can remove that operand later on.
     for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
       Value updatedCondition = deallocOp.getUpdatedConditions()[i];
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMust() || analysisResult.isPartial()) {
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (analysisResult == true) {
         auto disjunction = rewriter.create<arith::OrIOp>(
             deallocOp.getLoc(), updatedCondition, cond);
         rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
 struct RemoveRetainedMemrefsGuaranteedToNotAlias
     : public OpRewritePattern<DeallocOp> {
   RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
-                                            AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                            BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
 
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+      if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
                                    retainedMemref)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
 struct SplitDeallocWhenNotAliasingAnyOther
     : public OpRewritePattern<DeallocOp> {
   SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
-                                      AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                      BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
       SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
       otherMemrefs.erase(otherMemrefs.begin() + i);
       // Check if `memref` can split off into a separate bufferization.dealloc.
-      if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+      if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
         // `memref` alias with other memrefs, do not split off.
         remainingMemrefs.push_back(memref);
         remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
     : public OpRewritePattern<DeallocOp> {
   RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
-                                                AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                                BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!matchPattern(cond, m_One()))
           continue;
 
-        AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-        if (analysisResult.isMust() || analysisResult.isPartial()) {
+        std::optional<bool> analysisResult =
+            analysis.isSameAllocation(retained, memref);
+        if (analysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!extractOp)
           continue;
 
-        AliasResult extractAnalysisResult =
-            aliasAnalysis.alias(retained, extractOp.getOperand());
-        if (extractAnalysisResult.isMust() ||
-            extractAnalysisResult.isPartial()) {
+        std::optional<bool> extractAnalysisResult =
+            analysis.isSameAllocation(retained, extractOp.getOperand());
+        if (extractAnalysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 } // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
     : public bufferization::impl::BufferDeallocationSimplificationBase<
           BufferDeallocationSimplificationPass> {
   void runOnOperation() override {
-    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    BufferOriginAnalysis analysis(getOperation());
     RewritePatternSet patterns(&getContext());
     patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
-                                                                aliasAnalysis);
+                                                                analysis);
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
     if (failed(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057425f366..72f47b8b468ea6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
 
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
-  ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+  BufferViewFlowAnalysis::ValueSetT result;
   SmallVector<Value, 8> queue;
-  queue.push_back(rootValue);
+  queue.push_back(value);
   while (!queue.empty()) {
     Value currentValue = queue.pop_back_val();
     if (result.insert(currentValue).second) {
-      auto it = dependencies.find(currentValue);
-      if (it != dependencies.end()) {
+      auto it = map.find(currentValue);
+      if (it != map.end()) {
         for (Value aliasValue : it->second)
           queue.push_back(aliasValue);
       }
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
   return result;
 }
 
+/// Find all immediate and indirect dependent buffers `rootValue` could
+/// potentially have. The result also contains `rootValue` itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+  return resolveValues(dependencies, rootValue);
+}
+
+/// Like `resolve`, but follows the dependencies backwards (towards origins).
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+  return resolveValues(reverseDependencies, rootValue);
+}
+
 /// Removes the given values from all alias sets.
 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
   for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
       this->dependencies[value].insert(dep);
+      this->reverseDependencies[dep].insert(value);
+    }
   };
 
   // Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
   assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
   return terminals.contains(value);
 }
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if `v` is an op result that its defining op allocates.
+static bool hasAllocateSideEffect(Value v) {
+  if (Operation *op = v.getDefiningOp())
+    return hasEffect<MemoryEffects::Allocate>(op, v);
+  // Values without a defining op (block arguments) are never allocations.
+  return false;
+}
+
+/// Return "true" if the given value is a block argument of the entry block of
+/// an op that implements FunctionOpInterface.
+static bool isFunctionArgument(Value v) {
+  auto bbArg = dyn_cast<BlockArgument>(v);
+  if (!bbArg)
+    return false;
+  Block *b = bbArg.getOwner();
+  // `Block::getParentOp` may return nullptr (e.g., for an unlinked block).
+  auto funcOp = dyn_cast_or_null<FunctionOpInterface>(b->getParentOp());
+  return funcOp && b == &funcOp.getFunctionBody().front();
+}
+
+/// Return `value` with all ViewLikeOpInterface ops (if any) peeled off.
+static Value getViewBase(Value value) {
+  for (auto view = value.getDefiningOp<ViewLikeOpInterface>(); view;
+       view = value.getDefiningOp<ViewLikeOpInterface>())
+    value = view.getViewSource();
+  return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+  // Skip over all view-like ops.
+  v1 = getViewBase(v1);
+  v2 = getViewBase(v2);
+
+  // Fast path: If both buffers are the same SSA value, we can be sure that
+  // they originate from the same allocation.
+  if (v1 == v2)
+    return true;
+
+  // Compute the SSA values from which the buffers `v1` and `v2` originate.
+  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+  // Originating buffers are "terminal" if they could not be traced back any
+  // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+  // - function block arguments
+  // - values defined by allocation ops such as "memref.alloc"
+  // - values defined by ops that are unknown to the buffer view flow analysis
+  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+  SmallPtrSet<Value, 16> terminal1, terminal2;
+
+  // While gathering terminal buffers, keep track of whether all terminal
+  // buffers are newly allocated buffers or function entry arguments.
+  bool allAllocs1 = true, allAllocs2 = true;
+  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+  // Helper function that gathers terminal buffers among `origin`.
+  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+                                      SmallPtrSet<Value, 16> &terminal,
+                                      bool &allAllocs,
+                                      bool &allAllocsOrFuncEntryArgs) {
+    for (Value v : origin) {
+      if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+        terminal.insert(v);
+        bool isAlloc = hasAllocateSideEffect(v);
+        allAllocs &= isAlloc;
+        allAllocsOrFuncEntryArgs &= isAlloc || isFunctionArgument(v);
+      }
+    }
+    assert(!terminal.empty() && "expected non-empty terminal set");
+  };
+
+  // Gather terminal buffers for `v1` and `v2`.
+  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+                        allAllocsOrFuncEntryArgs1);
+  gatherTerminalBuffers(origin2, terminal2, allAllocs2,
+                        allAllocsOrFuncEntryArgs2);
+
+  // If both `v1` and `v2` have a single matching terminal buffer, they are
+  // guaranteed to originate from the same buffer allocation.
+  if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
+      *terminal1.begin() == *terminal2.begin())
+    return true;
+
+  // Otherwise, at least one of the two values has multiple terminals, or both
+  // have a single terminal but those terminals are different SSA values.
+  // Check if there is overlap between the terminal buffers of `v1` and `v2`.
+  bool overlappingTerminalSets = llvm::any_of(
+      terminal1, [&](Value v) { return terminal2.contains(v); });
+  // If there is overlap between the terminal buffers of `v1` and `v2`, we
+  // cannot make an accurate decision without further analysis. E.g., `v1`
+  // may originate from %a or %b, while `v2` originates from %a.
+  if (overlappingTerminalSets)
+    return std::nullopt;
+
+  // If `v1` originates from only allocs, and `v2` is guaranteed to originate
+  // from different allocations (that is guaranteed if `v2` originates from
+  // only distinct allocs or function entry arguments), we can be sure that
+  // `v1` and `v2` originate from different allocations. The same argument can
+  // be made when swapping `v1` and `v2`. (Since `allAllocs` implies
+  // `allAllocsOrFuncEntryArgs`, checking the latter covers both cases.)
+  if ((allAllocs1 && allAllocsOrFuncEntryArgs2) ||
+      (allAllocsOrFuncEntryArgs1 && allAllocs2))
+    return false;
+
+  // Otherwise: We do not know whether `v1` and `v2` originate from the same
+  // allocation or not.
+  // TODO: Function arguments are currently handled conservatively. We assume
+  // that they could be the same allocation.
+  // TODO: Terminals other than allocations and function arguments are
+  // currently handled conservatively. We assume that they could be the same
+  // allocation. E.g., we currently return "nullopt" for values that originate
+  // from different "memref.get_global" ops (with different symbols).
+  return std::nullopt;
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
new file mode 100644
index 00000000000000..53b28c3aab6fd8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -expand-realloc="emit-deallocs=false" -ownership-based-buffer-deallocation="private-function-dynamic-ownership=true" -canonicalize -buffer-deallocation-simplification | FileCheck %s
+
+// A function that reallocates two buffers inside of a loop. The simplification
+// pass should be able to figure out that the iter_args always originate from
+// different allocations. IR like this appears in the sparse compiler.
+
+// CHECK-LABEL: func private @loop_with_realloc(
+func.func private @loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+  // CHECK-DAG: %[[false:.*]] = arith.constant false
+  // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+  // CHECK: %[[m0:.*]] = memref.alloc
+  %m0 = memref.alloc(%s1) : memref<?xf32>
+  // CHECK: %[[m1:.*]] = memref.alloc
+  %m1 = memref.alloc(%s1) : memref<?xf32>
+
+  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+    //      CHECK: %[[m2:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+    // CHECK-NEXT:   memref.alloc
+    // CHECK-NEXT:   memref.subview
+    // CHECK-NEXT:   memref.copy
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
+    // CHECK-NEXT: } else {
+    // CHECK-NEXT:   memref.reinterpret_cast
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
+    // CHECK-NEXT: }
+    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+    //      CHECK: %[[m3:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+    // CHECK-NEXT:   memref.alloc
+    // CHECK-NEXT:   memref.subview
+    // CHECK-NEXT:   memref.copy
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
+    // CHECK-NEXT: } else {
+    // CHECK-NEXT:   memref.reinterpret_cast
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
+    // CHECK-NEXT: }
+    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+
+    // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg0]]
+    // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg1]]
+    // CHECK: %[[d0:.*]] = bufferization.dealloc (%[[base0]] : memref<f32>) if (%[[o0]]) retain (%[[m2]]#0 : memref<?xf32>)
+    // CHECK: %[[d1:.*]] = bufferization.dealloc (%[[base1]] : memref<f32>) if (%[[o1]]) retain (%[[m3]]#0 : memref<?xf32>)
+    // CHECK-DAG: %[[o2:.*]] = arith.ori %[[d0]], %[[m2]]#1
+    // CHECK-DAG: %[[o3:.*]] = arith.ori %[[d1]], %[[m3]]#1
+    // CHECK: scf.yield %[[m2]]#0, %[[m3]]#0, %[[o2]], %[[o3]]
+    scf.yield %m2, %m3 : memref<?xf32>, memref<?xf32>
+  }
+
+  // CHECK: %[[d2:.*]] = bufferization.dealloc (%[[m0]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#0 : memref<?xf32>)
+  // CHECK: %[[d3:.*]] = bufferization.dealloc (%[[m1]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#1 : memref<?xf32>)
+  // CHECK-DAG: %[[or0:.*]] = arith.ori %[[d2]], %[[r]]#2
+  // CHECK-DAG: %[[or1:.*]] = arith.ori %[[d3]], %[[r]]#3
+  // CHECK: return %[[r]]#0, %[[r]]#1, %[[or0]], %[[or1]]
+  return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
+// The yielded values of the loop are swapped. Therefore, the
+// bufferization.dealloc before the func.return can no longer be split,
+// because %r0 could originate from either %m0 or %m1 (same for %r1).
+
+// CHECK-LABEL: func private @swapping_loop_with_realloc(
+func.func private @swapping_loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+  // CHECK-DAG: %[[false:.*]] = arith.constant false
+  // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+  // CHECK: %[[m0:.*]] = memref.alloc
+  %m0 = memref.alloc(%s1) : memref<?xf32>
+  // CHECK: %[[m1:.*]] = memref.alloc
+  %m1 = memref.alloc(%s1) : memref<?xf32>
+
+  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+    scf.yield %m3, %m2 : memref<?xf32>, memref<?xf32>
+  }
+
+  // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#0
+  // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#1
+  // CHECK: %[[d:.*]]:2 = bufferization.dealloc (%[[m0]], %[[m1]], %[[base0]], %[[base1]] : {{.*}}) if (%[[true]], %[[true]], %[[r]]#2, %[[r]]#3) retain (%[[r]]#0, %[[r]]#1 : {{.*}})
+  // CHECK: return %[[r]]#0, %[[r]]#1, %[[d]]#0, %[[d]]#1
+  return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index eee69acbe821b3..b40a17cf800bf3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -92,15 +92,13 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
 //  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
-// COM: there is only one value in the retained list because the
-// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
-// COM: removes %arg1 from the list. In the second dealloc, this does not apply
-// COM: because function arguments are assumed potentially alias (even if the
-// COM: types don't exactly match).
+// COM: there is only one value in the retained lists because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
+// COM: - %alloc is guaranteed to not alias with %arg1.
+// COM: - %arg2 is guaranteed to not alias with %0.
 //  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
-//  CHECK-NEXT:   [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
-//  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-//  CHECK-NEXT:   return [[V2]]#0, [[V3]] :
+//  CHECK-NEXT:   [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
+//  CHECK-NEXT:   return [[V2]], [[V1]] :
 
 // -----
 



More information about the Mlir-commits mailing list