[Mlir-commits] [mlir] [mlir][bufferization] Add `BufferOriginAnalysis` (PR #86461)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 24 19:58:36 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86461
>From be4524ef37903229a57159a6d7b668c83551b7a3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 25 Mar 2024 02:57:40 +0000
Subject: [PATCH] [mlir][bufferization] BufferOriginAnalysis
---
.../Bufferization/IR/BufferizationOps.td | 1 +
.../Transforms/BufferViewFlowAnalysis.h | 36 ++++
.../BufferDeallocationSimplification.cpp | 80 ++++-----
.../Transforms/BufferViewFlowAnalysis.cpp | 160 ++++++++++++++++--
.../dealloc-loops.mlir | 86 ++++++++++
.../buffer-deallocation-simplification.mlir | 14 +-
6 files changed, 321 insertions(+), 56 deletions(-)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c8..4f609ddff9a413 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
#define BUFFERIZATION_OPS
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 9e43265c5dfede..4015231c845daf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
///
/// Results in resolve(B) returning {B, C}
ValueSetT resolve(Value value) const;
+ ValueSetT resolveReverse(Value value) const;
/// Removes the given values from all alias sets.
void remove(const SetVector<Value> &aliasValues);
@@ -73,11 +74,46 @@ class BufferViewFlowAnalysis {
/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
+ ValueMapT reverseDependencies;
/// A set of all SSA values that may be terminal buffers.
DenseSet<Value> terminals;
};
+/// An is-same-buffer analysis that checks if two SSA values belong to the same
+/// buffer allocation or not.
+class BufferOriginAnalysis {
+public:
+ BufferOriginAnalysis(Operation *op);
+
+ /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+ /// Return "false" if `v1` and `v2` originate from different allocations.
+ /// Return "nullopt" if we do not know for sure.
+ ///
+ /// Example 1: isSameAllocation(%0, %1) == true
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.subview %0
+ /// ```
+ ///
+ /// Example 2: isSameAllocation(%0, %1) == false
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.alloc()
+ /// ```
+ ///
+ /// Example 3: isSameAllocation(%0, %2) == nullopt
+ /// ```
+ /// %0 = memref.alloc()
+ /// %1 = memref.alloc()
+ /// %2 = arith.select %c, %0, %1
+ /// ```
+ std::optional<bool> isSameAllocation(Value v1, Value v2);
+
+private:
+ BufferViewFlowAnalysis analysis;
+};
+
} // namespace mlir
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERVIEWFLOWANALYSIS_H
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b4753..954485cfede3da 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
// Helpers
//===----------------------------------------------------------------------===//
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
ValueRange memrefs,
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
return success();
}
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
- value = viewLikeOp.getViewSource();
- return value;
-}
-
/// Return "true" if the given values are guaranteed to be different (and
/// non-aliasing) allocations based on the fact that one value is the result
/// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
/// often a requirement of optimization patterns that there cannot be any
/// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
ValueRange otherList, Value memref) {
for (auto other : otherList) {
if (distinctAllocAndBlockArgument(other, memref))
continue;
- if (!analysis.alias(other, memref).isNo())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(other, memref);
+ if (!analysisResult.has_value() || analysisResult == true)
return true;
}
return false;
@@ -129,8 +131,8 @@ namespace {
struct RemoveDeallocMemrefsContainedInRetained
: public OpRewritePattern<DeallocOp> {
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
/// The passed 'memref' must not have a may-alias relation to any retained
/// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
// deallocated in some situations and can thus not be dropped).
bool atLeastOneMustAlias = false;
for (Value retained : deallocOp.getRetained()) {
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMay())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (!analysisResult.has_value())
return failure();
- if (analysisResult.isMust() || analysisResult.isPartial())
+ if (analysisResult == true)
atLeastOneMustAlias = true;
}
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
// we can remove that operand later on.
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
auto disjunction = rewriter.create<arith::OrIOp>(
deallocOp.getLoc(), updatedCondition, cond);
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ struct RemoveDeallocMemrefsContainedInRetained
struct RemoveRetainedMemrefsGuaranteedToNotAlias
: public OpRewritePattern<DeallocOp> {
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newRetainedMemrefs, replacements;
for (auto retainedMemref : deallocOp.getRetained()) {
- if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+ if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
retainedMemref)) {
newRetainedMemrefs.push_back(retainedMemref);
replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
struct SplitDeallocWhenNotAliasingAnyOther
: public OpRewritePattern<DeallocOp> {
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
otherMemrefs.erase(otherMemrefs.begin() + i);
// Check if `memref` can split off into a separate bufferization.dealloc.
- if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+ if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
// `memref` alias with other memrefs, do not split off.
remainingMemrefs.push_back(memref);
remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
: public OpRewritePattern<DeallocOp> {
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!matchPattern(cond, m_One()))
continue;
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!extractOp)
continue;
- AliasResult extractAnalysisResult =
- aliasAnalysis.alias(retained, extractOp.getOperand());
- if (extractAnalysisResult.isMust() ||
- extractAnalysisResult.isPartial()) {
+ std::optional<bool> extractAnalysisResult =
+ analysis.isSameAllocation(retained, extractOp.getOperand());
+ if (extractAnalysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
: public bufferization::impl::BufferDeallocationSimplificationBase<
BufferDeallocationSimplificationPass> {
void runOnOperation() override {
- AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+ BufferOriginAnalysis analysis(getOperation());
RewritePatternSet patterns(&getContext());
patterns.add<RemoveDeallocMemrefsContainedInRetained,
RemoveRetainedMemrefsGuaranteedToNotAlias,
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
- aliasAnalysis);
+ analysis);
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057425f366..72f47b8b468ea6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
using namespace mlir;
using namespace mlir::bufferization;
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
- ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+ BufferViewFlowAnalysis::ValueSetT result;
SmallVector<Value, 8> queue;
- queue.push_back(rootValue);
+ queue.push_back(value);
while (!queue.empty()) {
Value currentValue = queue.pop_back_val();
if (result.insert(currentValue).second) {
- auto it = dependencies.find(currentValue);
- if (it != dependencies.end()) {
+ auto it = map.find(currentValue);
+ if (it != map.end()) {
for (Value aliasValue : it->second)
queue.push_back(aliasValue);
}
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
return result;
}
+/// Find all immediate and indirect dependent buffers this value could
+/// potentially have. Note that the resulting set will also contain the value
+/// provided as it is a dependent alias of itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+ return resolveValues(dependencies, rootValue);
+}
+
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+ return resolveValues(reverseDependencies, rootValue);
+}
+
/// Removes the given values from all alias sets.
void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
this->dependencies[value].insert(dep);
+ this->reverseDependencies[dep].insert(value);
+ }
};
// Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
return terminals.contains(value);
}
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if the given value is the result of a memory allocation.
+static bool hasAllocateSideEffect(Value v) {
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return false;
+ return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if the given value is a function block argument.
+static bool isFunctionArgument(Value v) {
+ auto bbArg = dyn_cast<BlockArgument>(v);
+ if (!bbArg)
+ return false;
+ Block *b = bbArg.getOwner();
+ auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+ if (!funcOp)
+ return false;
+ return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+ assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+ assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+ // Skip over all view-like ops.
+ v1 = getViewBase(v1);
+ v2 = getViewBase(v2);
+
+ // Fast path: If both buffers are the same SSA value, we can be sure that
+ // they originate from the same allocation.
+ if (v1 == v2)
+ return true;
+
+ // Compute the SSA values from which the buffers `v1` and `v2` originate.
+ SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+ SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+ // Originating buffers are "terminal" if they could not be traced back any
+ // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+ // - function block arguments
+ // - values defined by allocation ops such as "memref.alloc"
+ // - values defined by ops that are unknown to the buffer view flow analysis
+ // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+ SmallPtrSet<Value, 16> terminal1, terminal2;
+
+ // While gathering terminal buffers, keep track of whether all terminal
+ // buffers are newly allocated buffers or function entry arguments.
+ bool allAllocs1 = true, allAllocs2 = true;
+ bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+ // Helper function that gathers terminal buffers among `origin`.
+ auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+ SmallPtrSet<Value, 16> &terminal,
+ bool &allAllocs,
+ bool &allAllocsOrFuncEntryArgs) {
+ for (Value v : origin) {
+ if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+ terminal.insert(v);
+ allAllocs &= hasAllocateSideEffect(v);
+ allAllocsOrFuncEntryArgs &=
+ isFunctionArgument(v) || hasAllocateSideEffect(v);
+ }
+ }
+ assert(!terminal.empty() && "expected non-empty terminal set");
+ };
+
+ // Gather terminal buffers for `v1` and `v2`.
+ gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+ allAllocsOrFuncEntryArgs1);
+ gatherTerminalBuffers(origin2, terminal2, allAllocs2,
+ allAllocsOrFuncEntryArgs2);
+
+ // If both `v1` and `v2` have a single matching terminal buffer, they are
+ // guaranteed to originate from the same buffer allocation.
+ if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
+ *terminal1.begin() == *terminal2.begin())
+ return true;
+
+ // At least one value has multiple terminals, or the single terminals differ.
+
+ // Check if there is overlap between the terminal buffers of `v1` and `v2`.
+ bool distinctTerminalSets = true;
+ for (Value v : terminal1)
+ distinctTerminalSets &= !terminal2.contains(v);
+ // If there is overlap between the terminal buffers of `v1` and `v2`, we
+ // cannot make an accurate decision without further analysis.
+ if (!distinctTerminalSets)
+ return std::nullopt;
+
+ // If `v1` originates from only allocs, and `v2` is guaranteed to originate
+ // from different allocations (that is guaranteed if `v2` originates from
+ // only distinct allocs or function entry arguments), we can be sure that
+ // `v1` and `v2` originate from different allocations. The same argument can
+ // be made when swapping `v1` and `v2`.
+ bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
+ bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
+ if (isolatedAlloc1 || isolatedAlloc2)
+ return false;
+
+ // Otherwise: We do not know whether `v1` and `v2` originate from the same
+ // allocation or not.
+ // TODO: Function arguments are currently handled conservatively. We assume
+ // that they could be the same allocation.
+ // TODO: Terminals other than allocations and function arguments are
+ // currently handled conservatively. We assume that they could be the same
+ // allocation. E.g., we currently return "nullopt" for values that originate
+ // from different "memref.get_global" ops (with different symbols).
+ return std::nullopt;
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
new file mode 100644
index 00000000000000..53b28c3aab6fd8
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-loops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -expand-realloc="emit-deallocs=false" -ownership-based-buffer-deallocation="private-function-dynamic-ownership=true" -canonicalize -buffer-deallocation-simplification | FileCheck %s
+
+// A function that reallocates two buffers inside of a loop. The simplification
+// pass should be able to figure out that the iter_args are always originating
+// from different allocations. IR like this one appears in the sparse compiler.
+
+// CHECK-LABEL: func private @loop_with_realloc(
+func.func private @loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+ // CHECK-DAG: %[[false:.*]] = arith.constant false
+ // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+ // CHECK: %[[m0:.*]] = memref.alloc
+ %m0 = memref.alloc(%s1) : memref<?xf32>
+ // CHECK: %[[m1:.*]] = memref.alloc
+ %m1 = memref.alloc(%s1) : memref<?xf32>
+
+ // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+ %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+ // CHECK: %[[m2:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+ // CHECK-NEXT: memref.alloc
+ // CHECK-NEXT: memref.subview
+ // CHECK-NEXT: memref.copy
+ // CHECK-NEXT: scf.yield %{{.*}}, %[[true]]
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: memref.reinterpret_cast
+ // CHECK-NEXT: scf.yield %{{.*}}, %[[false]]
+ // CHECK-NEXT: }
+ %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+ // CHECK: %[[m3:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
+ // CHECK-NEXT: memref.alloc
+ // CHECK-NEXT: memref.subview
+ // CHECK-NEXT: memref.copy
+ // CHECK-NEXT: scf.yield %{{.*}}, %[[true]]
+ // CHECK-NEXT: } else {
+ // CHECK-NEXT: memref.reinterpret_cast
+ // CHECK-NEXT: scf.yield %{{.*}}, %[[false]]
+ // CHECK-NEXT: }
+ %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+
+ // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg0]]
+ // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg1]]
+ // CHECK: %[[d0:.*]] = bufferization.dealloc (%[[base0]] : memref<f32>) if (%[[o0]]) retain (%[[m2]]#0 : memref<?xf32>)
+ // CHECK: %[[d1:.*]] = bufferization.dealloc (%[[base1]] : memref<f32>) if (%[[o1]]) retain (%[[m3]]#0 : memref<?xf32>)
+ // CHECK-DAG: %[[o2:.*]] = arith.ori %[[d0]], %[[m2]]#1
+ // CHECK-DAG: %[[o3:.*]] = arith.ori %[[d1]], %[[m3]]#1
+ // CHECK: scf.yield %[[m2]]#0, %[[m3]]#0, %[[o2]], %[[o3]]
+ scf.yield %m2, %m3 : memref<?xf32>, memref<?xf32>
+ }
+
+ // CHECK: %[[d2:.*]] = bufferization.dealloc (%[[m0]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#0 : memref<?xf32>)
+ // CHECK: %[[d3:.*]] = bufferization.dealloc (%[[m1]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#1 : memref<?xf32>)
+ // CHECK-DAG: %[[or0:.*]] = arith.ori %[[d2]], %[[r]]#2
+ // CHECK-DAG: %[[or1:.*]] = arith.ori %[[d3]], %[[r]]#3
+ // CHECK: return %[[r]]#0, %[[r]]#1, %[[or0]], %[[or1]]
+ return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
+// The yielded values of the loop are swapped. Therefore, the
+// bufferization.dealloc before the func.return can no longer be split,
+// because %r0 could originate from either %m0 or %m1 (same for %r1).
+
+// CHECK-LABEL: func private @swapping_loop_with_realloc(
+func.func private @swapping_loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
+ // CHECK-DAG: %[[false:.*]] = arith.constant false
+ // CHECK-DAG: %[[true:.*]] = arith.constant true
+
+ // CHECK: %[[m0:.*]] = memref.alloc
+ %m0 = memref.alloc(%s1) : memref<?xf32>
+ // CHECK: %[[m1:.*]] = memref.alloc
+ %m1 = memref.alloc(%s1) : memref<?xf32>
+
+ // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
+ %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
+ %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
+ %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
+ scf.yield %m3, %m2 : memref<?xf32>, memref<?xf32>
+ }
+
+ // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#0
+ // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#1
+ // CHECK: %[[d:.*]]:2 = bufferization.dealloc (%[[m0]], %[[m1]], %[[base0]], %[[base1]] : {{.*}}) if (%[[true]], %[[true]], %[[r]]#2, %[[r]]#3) retain (%[[r]]#0, %[[r]]#1 : {{.*}})
+ // CHECK: return %[[r]]#0, %[[r]]#1, %[[d]]#0, %[[d]]#1
+ return %r0, %r1 : memref<?xf32>, memref<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index eee69acbe821b3..b40a17cf800bf3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -92,15 +92,13 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
// CHECK-NEXT: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-NEXT: [[ALLOC1:%.+]] = memref.alloc(
// CHECK-NEXT: [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
-// COM: there is only one value in the retained list because the
-// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
-// COM: removes %arg1 from the list. In the second dealloc, this does not apply
-// COM: because function arguments are assumed potentially alias (even if the
-// COM: types don't exactly match).
+// COM: there is only one value in the retained lists because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
+// COM: - %alloc is guaranteed to not alias with %arg1.
+// COM: - %arg2 is guaranteed to not alias with %0.
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
-// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
-// CHECK-NEXT: [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-// CHECK-NEXT: return [[V2]]#0, [[V3]] :
+// CHECK-NEXT: [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
+// CHECK-NEXT: return [[V2]], [[V1]] :
// -----
More information about the Mlir-commits
mailing list