[Mlir-commits] [mlir] dbfc38e - [mlir][bufferization] Add `BufferOriginAnalysis` (#86461)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 25 02:57:57 PDT 2024


Author: Matthias Springer
Date: 2024-03-25T18:57:53+09:00
New Revision: dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6

URL: https://github.com/llvm/llvm-project/commit/dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6
DIFF: https://github.com/llvm/llvm-project/commit/dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6.diff

LOG: [mlir][bufferization] Add `BufferOriginAnalysis` (#86461)

This commit adds the `BufferOriginAnalysis`, which can be queried to
check if two buffer SSA values originate from the same allocation. This
new analysis is used in the buffer deallocation pass to fold away or
simplify `bufferization.dealloc` ops more aggressively.

The `BufferOriginAnalysis` is based on the `BufferViewFlowAnalysis`,
which collects buffer SSA value "same buffer" dependencies. E.g., given
IR such as:
```
%0 = memref.alloc()
%1 = memref.subview %0
%2 = memref.subview %1
```
The `BufferViewFlowAnalysis` will report the following "reverse"
dependencies (`resolveReverse`) for `%2`: {`%2`, `%1`, `%0`}. I.e., all
buffer SSA values in the reverse use-def chain that originate from the
same allocation as `%2`. The `BufferOriginAnalysis` is built on top of
that. It handles only simple cases at the moment and may conservatively
return "unknown" around certain IR with branches, memref globals and
function arguments.

This analysis enables additional simplifications during
`-buffer-deallocation-simplification`. In particular, "regular" scf.for
loop nests that yield buffers (or reallocations thereof) in the same
order as they appear in the iter_args are now handled much more
efficiently. Such IR patterns are generated by the sparse compiler.

Added: 
    mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c8..4f609ddff9a413 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
 #define BUFFERIZATION_OPS
 
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 9e43265c5dfede..4015231c845daf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
   ///
   /// Results in resolve(B) returning {B, C}
   ValueSetT resolve(Value value) const;
+  ValueSetT resolveReverse(Value value) const;
 
   /// Removes the given values from all alias sets.
   void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+  ValueMapT reverseDependencies;
 
   /// A set of all SSA values that may be terminal buffers.
   DenseSet<Value> terminals;
 };
 
+/// An is-same-buffer analysis, built on top of `BufferViewFlowAnalysis`, that
+/// checks if two SSA values belong to the same buffer allocation or not.
+class BufferOriginAnalysis {
+public:
+  BufferOriginAnalysis(Operation *op);
+
+  /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+  /// Return "false" if `v1` and `v2` originate from different allocations.
+  /// Return "nullopt" if we do not know for sure.
+  ///
+  /// Example 1: isSameAllocation(%0, %1) == true
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.subview %0
+  /// ```
+  ///
+  /// Example 2: isSameAllocation(%0, %1) == false
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// ```
+  ///
+  /// Example 3: isSameAllocation(%0, %2) == nullopt
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// %2 = arith.select %c, %0, %1
+  /// ```
+  std::optional<bool> isSameAllocation(Value v1, Value v2);
+
+private:
+  BufferViewFlowAnalysis analysis;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b4753..954485cfede3da 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
 // Helpers
 //===----------------------------------------------------------------------===//
 
+/// Given a memref value, return its "base": the first value along the reverse
+/// use-def chain (starting at `value`) that no ViewLikeOpInterface op defines.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                             ValueRange memrefs,
                                             ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
-  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
-    value = viewLikeOp.getViewSource();
-  return value;
-}
-
 /// Return "true" if the given values are guaranteed to be different (and
 /// non-aliasing) allocations based on the fact that one value is the result
 /// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
 /// often a requirement of optimization patterns that there cannot be any
 /// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
                                      ValueRange otherList, Value memref) {
   for (auto other : otherList) {
     if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(other, memref).isNo())
+    std::optional<bool> analysisResult =
+        analysis.isSameAllocation(other, memref);
+    if (!analysisResult.has_value() || analysisResult == true)
       return true;
   }
   return false;
@@ -129,8 +131,8 @@ namespace {
 struct RemoveDeallocMemrefsContainedInRetained
     : public OpRewritePattern<DeallocOp> {
   RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
-                                          AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                          BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   /// The passed 'memref' must not have a may-alias relation to any retained
   /// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
     // deallocated in some situations and can thus not be dropped).
     bool atLeastOneMustAlias = false;
     for (Value retained : deallocOp.getRetained()) {
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMay())
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (!analysisResult.has_value())
         return failure();
-      if (analysisResult.isMust() || analysisResult.isPartial())
+      if (analysisResult == true)
         atLeastOneMustAlias = true;
     }
     if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
     // we can remove that operand later on.
     for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
       Value updatedCondition = deallocOp.getUpdatedConditions()[i];
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMust() || analysisResult.isPartial()) {
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (analysisResult == true) {
         auto disjunction = rewriter.create<arith::OrIOp>(
             deallocOp.getLoc(), updatedCondition, cond);
         rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
 struct RemoveRetainedMemrefsGuaranteedToNotAlias
     : public OpRewritePattern<DeallocOp> {
   RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
-                                            AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                            BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
 
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+      if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
                                    retainedMemref)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
 struct SplitDeallocWhenNotAliasingAnyOther
     : public OpRewritePattern<DeallocOp> {
   SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
-                                      AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                      BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
       SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
       otherMemrefs.erase(otherMemrefs.begin() + i);
       // Check if `memref` can split off into a separate bufferization.dealloc.
-      if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+      if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
         // `memref` alias with other memrefs, do not split off.
         remainingMemrefs.push_back(memref);
         remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
     : public OpRewritePattern<DeallocOp> {
   RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
-                                                AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                                BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!matchPattern(cond, m_One()))
           continue;
 
-        AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-        if (analysisResult.isMust() || analysisResult.isPartial()) {
+        std::optional<bool> analysisResult =
+            analysis.isSameAllocation(retained, memref);
+        if (analysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!extractOp)
           continue;
 
-        AliasResult extractAnalysisResult =
-            aliasAnalysis.alias(retained, extractOp.getOperand());
-        if (extractAnalysisResult.isMust() ||
-            extractAnalysisResult.isPartial()) {
+        std::optional<bool> extractAnalysisResult =
+            analysis.isSameAllocation(retained, extractOp.getOperand());
+        if (extractAnalysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 } // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
     : public bufferization::impl::BufferDeallocationSimplificationBase<
           BufferDeallocationSimplificationPass> {
   void runOnOperation() override {
-    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    BufferOriginAnalysis analysis(getOperation());
     RewritePatternSet patterns(&getContext());
     patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
-                                                                aliasAnalysis);
+                                                                analysis);
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
     if (failed(

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057425f366..72f47b8b468ea6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
 
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
-  ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+  BufferViewFlowAnalysis::ValueSetT result;
   SmallVector<Value, 8> queue;
-  queue.push_back(rootValue);
+  queue.push_back(value);
   while (!queue.empty()) {
     Value currentValue = queue.pop_back_val();
     if (result.insert(currentValue).second) {
-      auto it = dependencies.find(currentValue);
-      if (it != dependencies.end()) {
+      auto it = map.find(currentValue);
+      if (it != map.end()) {
         for (Value aliasValue : it->second)
           queue.push_back(aliasValue);
       }
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
   return result;
 }
 
+/// Find all immediate and indirect dependent buffers this value could
+/// potentially have (`resolve`) or, following reverse edges, the buffers it
+/// may originate from (`resolveReverse`). Both include the value itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+  return resolveValues(dependencies, rootValue);
+}
+
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+  return resolveValues(reverseDependencies, rootValue);
+}
+
 /// Removes the given values from all alias sets.
 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
   for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
       this->dependencies[value].insert(dep);
+      this->reverseDependencies[dep].insert(value);
+    }
   };
 
   // Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
   assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
   return terminals.contains(value);
 }
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if the op defining `v` has an Allocate effect on `v`.
+static bool hasAllocateSideEffect(Value v) {
+  Operation *op = v.getDefiningOp();
+  if (!op)
+    return false;
+  return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if `v` is an argument of a FunctionOpInterface's entry block.
+static bool isFunctionArgument(Value v) {
+  auto bbArg = dyn_cast<BlockArgument>(v);
+  if (!bbArg)
+    return false;
+  Block *b = bbArg.getOwner();
+  auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+  if (!funcOp)
+    return false;
+  return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Given a memref value, return its "base": the first value along the reverse
+/// use-def chain (starting at `value`) that no ViewLikeOpInterface op defines.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+  // Skip over all view-like ops.
+  v1 = getViewBase(v1);
+  v2 = getViewBase(v2);
+
+  // Fast path: If both buffers are the same SSA value, we can be sure that
+  // they originate from the same allocation.
+  if (v1 == v2)
+    return true;
+
+  // Compute the SSA values from which the buffers `v1` and `v2` originate.
+  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+  // Originating buffers are "terminal" if they could not be traced back any
+  // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+  // - function block arguments
+  // - values defined by allocation ops such as "memref.alloc"
+  // - values defined by ops that are unknown to the buffer view flow analysis
+  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+  SmallPtrSet<Value, 16> terminal1, terminal2;
+
+  // While gathering terminal buffers, keep track of whether all terminal
+  // buffers are newly allocated buffers or function entry arguments.
+  bool allAllocs1 = true, allAllocs2 = true;
+  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+  // Helper function that gathers terminal buffers among `origin`.
+  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+                                      SmallPtrSet<Value, 16> &terminal,
+                                      bool &allAllocs,
+                                      bool &allAllocsOrFuncEntryArgs) {
+    for (Value v : origin) {
+      if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+        terminal.insert(v);
+        allAllocs &= hasAllocateSideEffect(v);
+        allAllocsOrFuncEntryArgs &=
+            isFunctionArgument(v) || hasAllocateSideEffect(v);
+      }
+    }
+    assert(!terminal.empty() && "expected non-empty terminal set");
+  };
+
+  // Gather terminal buffers for `v1` and `v2`.
+  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+                        allAllocsOrFuncEntryArgs1);
+  gatherTerminalBuffers(origin2, terminal2, allAllocs2,
+                        allAllocsOrFuncEntryArgs2);
+
+  // If both `v1` and `v2` have a single matching terminal buffer, they are
+  // guaranteed to originate from the same buffer allocation.
+  if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
+      *terminal1.begin() == *terminal2.begin())
+    return true;
+
+  // Otherwise, the terminals differ or at least one value has several of them.
+
+  // Check if there is overlap between the terminal buffers of `v1` and `v2`.
+  bool distinctTerminalSets = true;
+  for (Value v : terminal1)
+    distinctTerminalSets &= !terminal2.contains(v);
+  // If there is overlap between the terminal buffers of `v1` and `v2`, we
+  // cannot make an accurate decision without further analysis.
+  if (!distinctTerminalSets)
+    return std::nullopt;
+
+  // If `v1` originates from only allocs, and `v2` is guaranteed to originate
+  // from different allocations (that is guaranteed if `v2` originates from
+  // only distinct allocs or function entry arguments), we can be sure that
+  // `v1` and `v2` originate from different allocations. The same argument can
+  // be made when swapping `v1` and `v2`.
+  bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
+  bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
+  if (isolatedAlloc1 || isolatedAlloc2)
+    return false;
+
+  // Otherwise: We do not know whether `v1` and `v2` originate from the same
+  // allocation or not.
+  // TODO: Function arguments are currently handled conservatively. We assume
+  // that they could be the same allocation.
+  // TODO: Terminals other than allocations and function arguments are
+  // currently handled conservatively. We assume that they could be the same
+  // allocation. E.g., we currently return "nullopt" for values that originate
+  // from different "memref.get_global" ops (with different symbols).
+  return std::nullopt;
+}

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
new file mode 100644
index 00000000000000..53b28c3aab6fd8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -expand-realloc="emit-deallocs=false" -ownership-based-buffer-deallocation="private-function-dynamic-ownership=true" -canonicalize -buffer-deallocation-simplification | FileCheck %s
+
+// A function that reallocates two buffers inside of a loop. The simplification
+// pass should be able to figure out that the iter_args are always originating
+// from different allocations. IR like this appears in the sparse compiler.
+
+// CHECK-LABEL: func private @loop_with_realloc(
+func.func private @loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+  // CHECK-DAG: %[[false:.*]] = arith.constant false
+  // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+  // CHECK: %[[m0:.*]] = memref.alloc
+  %m0 = memref.alloc(%s1) : memref<?xf32>
+  // CHECK: %[[m1:.*]] = memref.alloc
+  %m1 = memref.alloc(%s1) : memref<?xf32>
+
+  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+    //      CHECK: %[[m2:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+    // CHECK-NEXT:   memref.alloc
+    // CHECK-NEXT:   memref.subview
+    // CHECK-NEXT:   memref.copy
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
+    // CHECK-NEXT: } else {
+    // CHECK-NEXT:   memref.reinterpret_cast
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
+    // CHECK-NEXT: }
+    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+    //      CHECK: %[[m3:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+    // CHECK-NEXT:   memref.alloc
+    // CHECK-NEXT:   memref.subview
+    // CHECK-NEXT:   memref.copy
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
+    // CHECK-NEXT: } else {
+    // CHECK-NEXT:   memref.reinterpret_cast
+    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
+    // CHECK-NEXT: }
+    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+
+    // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg0]]
+    // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg1]]
+    // CHECK: %[[d0:.*]] = bufferization.dealloc (%[[base0]] : memref<f32>) if (%[[o0]]) retain (%[[m2]]#0 : memref<?xf32>)
+    // CHECK: %[[d1:.*]] = bufferization.dealloc (%[[base1]] : memref<f32>) if (%[[o1]]) retain (%[[m3]]#0 : memref<?xf32>)
+    // CHECK-DAG: %[[o2:.*]] = arith.ori %[[d0]], %[[m2]]#1
+    // CHECK-DAG: %[[o3:.*]] = arith.ori %[[d1]], %[[m3]]#1
+    // CHECK: scf.yield %[[m2]]#0, %[[m3]]#0, %[[o2]], %[[o3]]
+    scf.yield %m2, %m3 : memref<?xf32>, memref<?xf32>
+  }
+
+  // CHECK: %[[d2:.*]] = bufferization.dealloc (%[[m0]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#0 : memref<?xf32>)
+  // CHECK: %[[d3:.*]] = bufferization.dealloc (%[[m1]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#1 : memref<?xf32>)
+  // CHECK-DAG: %[[or0:.*]] = arith.ori %[[d2]], %[[r]]#2
+  // CHECK-DAG: %[[or1:.*]] = arith.ori %[[d3]], %[[r]]#3
+  // CHECK: return %[[r]]#0, %[[r]]#1, %[[or0]], %[[or1]]
+  return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
+// The yielded values of the loop are swapped. Therefore, the
+// bufferization.dealloc before the func.return can no longer be split,
+// because %r0 could originate from either %m0 or %m1 (same for %r1).
+
+// CHECK-LABEL: func private @swapping_loop_with_realloc(
+func.func private @swapping_loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+  // CHECK-DAG: %[[false:.*]] = arith.constant false
+  // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+  // CHECK: %[[m0:.*]] = memref.alloc
+  %m0 = memref.alloc(%s1) : memref<?xf32>
+  // CHECK: %[[m1:.*]] = memref.alloc
+  %m1 = memref.alloc(%s1) : memref<?xf32>
+
+  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+    scf.yield %m3, %m2 : memref<?xf32>, memref<?xf32>
+  }
+
+  // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#0
+  // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#1
+  // CHECK: %[[d:.*]]:2 = bufferization.dealloc (%[[m0]], %[[m1]], %[[base0]], %[[base1]] : {{.*}}) if (%[[true]], %[[true]], %[[r]]#2, %[[r]]#3) retain (%[[r]]#0, %[[r]]#1 : {{.*}})
+  // CHECK: return %[[r]]#0, %[[r]]#1, %[[d]]#0, %[[d]]#1
+  return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index eee69acbe821b3..b40a17cf800bf3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -92,15 +92,13 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
 //  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
-// COM: there is only one value in the retained list because the
-// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
-// COM: removes %arg1 from the list. In the second dealloc, this does not apply
-// COM: because function arguments are assumed potentially alias (even if the
-// COM: types don't exactly match).
+// COM: there is only one value in the retained lists because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
+// COM: - %alloc is guaranteed to not alias with %arg1.
+// COM: - %arg2 is guaranteed to not alias with %0.
 //  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
-//  CHECK-NEXT:   [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
-//  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-//  CHECK-NEXT:   return [[V2]]#0, [[V3]] :
+//  CHECK-NEXT:   [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
+//  CHECK-NEXT:   return [[V2]], [[V1]] :
 
 // -----
 


        


More information about the Mlir-commits mailing list