[Mlir-commits] [mlir] 37317f5 - [mlir][linalg][bufferize] Decouple BufferizationAliasInfo
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 4 19:51:22 PDT 2021
Author: Matthias Springer
Date: 2021-11-05T11:41:00+09:00
New Revision: 37317f5bd21297af49b1f3968e0b44cb3596f653
URL: https://github.com/llvm/llvm-project/commit/37317f5bd21297af49b1f3968e0b44cb3596f653
DIFF: https://github.com/llvm/llvm-project/commit/37317f5bd21297af49b1f3968e0b44cb3596f653.diff
LOG: [mlir][linalg][bufferize] Decouple BufferizationAliasInfo
Move dialect-specific and analysis-specific functions out of BufferizationAliasInfo. BufferizationAliasInfo's only job now is to keep track of aliases.
This is in preparation for further decoupling ComprehensiveBufferize from various dialects.
Differential Revision: https://reviews.llvm.org/D112992
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 94cb52b4bca5..4decaee4238b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -48,14 +48,6 @@ class BufferizationAliasInfo {
/// `alias`. Additionally, merge their equivalence classes.
void insertNewBufferEquivalence(Value newValue, Value alias);
- /// Return true if, under current bufferization decisions, the buffer of
- /// `value` is not writable.
- bool aliasesNonWritableBuffer(Value value) const;
-
- /// Return true if the buffer to which `operand` would bufferize is equivalent
- /// to some buffer write.
- bool aliasesInPlaceWrite(Value v) const;
-
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
void bufferizeInPlace(OpResult result, OpOperand &operand);
@@ -63,23 +55,6 @@ class BufferizationAliasInfo {
/// Set the inPlace bufferization spec to false.
void bufferizeOutOfPlace(OpResult result);
- /// Return true if `value` has an ExtractSliceOp matching the given
- /// InsertSliceOp in its reverse SSA use-def chain.
- bool hasMatchingExtractSliceOp(Value value,
- tensor::InsertSliceOp insertOp) const;
-
- /// Return true if bufferizing `opOperand` inplace with `opResult` would
- /// create a write to a non-writable buffer.
- bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
- OpResult opResult) const;
-
- /// Assume that result bufferizes in-place with one of the operation's
- /// operands. Return true if it is possible to find an inplace write W that
- /// creates a conflict.
- bool
- wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result,
- const DominanceInfo &domInfo) const;
-
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
// Return `false` if we have no information about `v1` or `v2`.
@@ -91,14 +66,13 @@ class BufferizationAliasInfo {
equivalentInfo.getLeaderValue(v2);
}
- /// Return true if the source of an `insertSliceOp` bufferizes to an
- /// equivalent ExtractSliceOp.
- bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- tensor::InsertSliceOp insertSliceOp) const;
-
/// Apply `fun` to all the members of the equivalence class of `v`.
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
+ /// Apply `fun` to all aliases of `v`.
+ void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
+
+ // TODO: Move these out of BufferizationAliasInfo.
/// Return true if the value is known to bufferize to writable memory.
bool bufferizesToWritableMemory(Value v) const;
@@ -128,22 +102,6 @@ class BufferizationAliasInfo {
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
- /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
- /// equivalent operand / result and same offset/sizes/strides specification).
- ///
- /// This is one particular type of relationship between ops on tensors that
- /// reduce to an equivalence on buffers. This should be generalized and
- /// exposed as interfaces on the proper types.
- bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st,
- tensor::InsertSliceOp sti) const;
-
- /// Given sets of uses and writes, return true if there is a RaW conflict
- /// under the assumption that all given reads/writes alias the same buffer and
- /// that all given writes bufferize inplace.
- bool hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
- const DenseSet<OpOperand *> &usesWrite,
- const DominanceInfo &domInfo) const;
-
/// Set of tensors that are known to bufferize to writable memory.
llvm::DenseSet<Value> bufferizeToWritableMemory;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 63145b157d55..28e432a88ad3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -508,6 +508,24 @@ static BufferRelation bufferRelation(OpOperand &opOperand) {
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduce to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+ ExtractSliceOp st, InsertSliceOp sti) {
+ if (!st || !sti)
+ return false;
+ if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+ return false;
+ if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+ return false;
+ return true;
+}
+
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand) {
// Ops that do not bufferize to a memory write, cannot be write in-place.
@@ -567,24 +585,27 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
/// Return true if, under current bufferization decisions, the buffer of `value`
/// is not writable.
-bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
+static bool aliasesNonWritableBuffer(Value value,
+ const BufferizationAliasInfo &aliasInfo) {
LDBG("----Start aliasesNonWritableBuffer\n");
- for (Value v : getAliases(value)) {
+ bool foundNonWritableBuffer = false;
+ aliasInfo.applyOnAliases(value, [&](Value v) {
LDBG("-----------examine: " << printValueInfo(v) << '\n');
- if (bufferizesToWritableMemory(v)) {
+ if (aliasInfo.bufferizesToWritableMemory(v)) {
LDBG("-----------Value is known to be writable -> skip: "
<< printValueInfo(v) << '\n');
- continue;
+ return;
}
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
if (getInPlace(bbArg) == InPlaceSpec::True) {
LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg)
<< '\n');
- continue;
+ return;
}
LDBG("-----------notWritable bbArg\n");
- return true;
+ foundNonWritableBuffer = true;
+ return;
}
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(v.getDefiningOp());
@@ -592,11 +613,15 @@ bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
// Unknown ops are treated conservatively: Assume that it is illegal to
// write to their OpResults in-place.
LDBG("-----------notWritable op\n");
- return true;
+ foundNonWritableBuffer = true;
+ return;
}
- }
- LDBG("---->value is writable\n");
- return false;
+ });
+
+ if (!foundNonWritableBuffer)
+ LDBG("---->value is writable\n");
+
+ return foundNonWritableBuffer;
}
bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const {
@@ -610,20 +635,26 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
/// Return true if the buffer to which `operand` would bufferize is equivalent
/// to some buffer write.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
+static bool aliasesInPlaceWrite(Value value,
+ const BufferizationAliasInfo &aliasInfo) {
LDBG("----Start aliasesInPlaceWrite\n");
LDBG("-------for : " << printValueInfo(value) << '\n');
- for (Value v : getAliases(value)) {
+ bool foundInplaceWrite = false;
+ aliasInfo.applyOnAliases(value, [&](Value v) {
for (auto &use : v.getUses()) {
if (isInplaceMemoryWrite(use)) {
LDBG("-----------wants to bufferize to inPlace write: "
<< printOperationInfo(use.getOwner()) << '\n');
- return true;
+ foundInplaceWrite = true;
+ return;
}
}
- }
- LDBG("----------->does not alias an inplace write\n");
- return false;
+ });
+
+ if (!foundInplaceWrite)
+ LDBG("----------->does not alias an inplace write\n");
+
+ return foundInplaceWrite;
}
/// Set the inPlace bufferization spec to true.
@@ -731,11 +762,11 @@ static Value findLastPrecedingWrite(Value value) {
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
-bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
- Value value, InsertSliceOp insertOp) const {
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+ Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(extractOp, insertOp))
+ if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
@@ -766,10 +797,11 @@ static bool happensBefore(Operation *a, Operation *b,
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
-bool BufferizationAliasInfo::hasReadAfterWriteInterference(
- const DenseSet<OpOperand *> &usesRead,
- const DenseSet<OpOperand *> &usesWrite,
- const DominanceInfo &domInfo) const {
+static bool
+hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
+ const DenseSet<OpOperand *> &usesWrite,
+ const DominanceInfo &domInfo,
+ const BufferizationAliasInfo &aliasInfo) {
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@@ -850,7 +882,8 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(uConflictingWrite->get(), insertSliceOp))
+ hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+ insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -867,7 +900,7 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(uRead->get(), insertSliceOp))
+ hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -910,8 +943,9 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
/// * However, adding an alias {%0, %t} would mean that the second
/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
/// would no longer be reading the result of %1.
-bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
- OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const {
+bool wouldCreateReadAfterWriteInterference(
+ OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
+ const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@@ -920,20 +954,22 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
// Helper function to iterate on aliases of `root` and capture the reads.
auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
- for (Value alias : getAliases(root))
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Read to a value that aliases root.
if (bufferizesToMemoryRead(use))
res.insert(&use);
+ });
};
// Helper function to iterate on aliases of `root` and capture the writes.
auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
- for (Value alias : getAliases(root))
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Inplace write to a value that aliases root.
if (isInplaceMemoryWrite(use))
res.insert(&use);
+ });
};
// Collect reads and writes of all aliases of OpOperand and OpResult.
@@ -945,13 +981,14 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
if (bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
- return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo);
+ return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
}
/// Return true if bufferizing `opOperand` inplace with `opResult` would create
/// a write to a non-writable buffer.
-bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
- OpOperand &opOperand, OpResult opResult) const {
+static bool
+wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
+ const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
@@ -961,15 +998,15 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
// Certain buffers are not writeable:
// 1. A function bbArg that is not inplaceable or
// 2. A constant op.
- assert(!aliasesNonWritableBuffer(opResult) &&
+ assert(!aliasesNonWritableBuffer(opResult, aliasInfo) &&
"expected that opResult does not alias non-writable buffer");
- bool nonWritable = aliasesNonWritableBuffer(opOperand.get());
+ bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo);
if (!nonWritable)
return false;
// This is a problem only if the buffer is written to via some alias.
- bool hasWrite = aliasesInPlaceWrite(opResult) ||
- aliasesInPlaceWrite(opOperand.get()) ||
+ bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) ||
+ aliasesInPlaceWrite(opOperand.get(), aliasInfo) ||
bufferizesToMemoryWrite(opOperand);
if (!hasWrite)
return false;
@@ -978,28 +1015,6 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
return true;
}
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- InsertSliceOp insertSliceOp) const {
- LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
- << '\n');
- auto leaderIt = equivalentInfo.findLeader(insertSliceOp.source());
- for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
- ++mit) {
- auto extractSliceOp =
- dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
- if (extractSliceOp &&
- areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
- getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
- LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
- return true;
- }
- }
- LDBG("\tnot equivalent\n");
- return false;
-}
-
/// Apply `fun` to all the members of the equivalence class of `v`.
void BufferizationAliasInfo::applyOnEquivalenceClass(
Value v, function_ref<void(Value)> fun) const {
@@ -1010,6 +1025,15 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
}
}
+/// Apply `fun` to all aliases of `v`.
+void BufferizationAliasInfo::applyOnAliases(
+ Value v, function_ref<void(Value)> fun) const {
+ auto leaderIt = aliasInfo.findLeader(v);
+ for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
+ fun(*mit);
+ }
+}
+
void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
os << "\n/===================== AliasInfo =====================\n";
for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) {
@@ -1066,20 +1090,6 @@ void BufferizationAliasInfo::dumpEquivalences() const {
printEquivalences(llvm::errs());
}
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and exposed
-/// as interfaces on the proper types.
-bool BufferizationAliasInfo::areEquivalentExtractSliceOps(
- ExtractSliceOp st, InsertSliceOp sti) const {
- if (!st || !sti)
- return false;
- if (!equivalentInfo.isEquivalent(st.source(), sti.dest()))
- return false;
- if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
//===----------------------------------------------------------------------===//
// Forward declarations.
//===----------------------------------------------------------------------===//
@@ -1475,8 +1485,9 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
<< printValueInfo(result) << '\n');
bool foundInterference =
- aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) ||
- aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo);
+ wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) ||
+ wouldCreateReadAfterWriteInterference(operand, result, domInfo,
+ aliasInfo);
if (foundInterference)
aliasInfo.bufferizeOutOfPlace(result);
@@ -3276,6 +3287,30 @@ struct ExtractOpInterface
}
};
+/// Return true if the source of a `insertSliceOp` bufferizes to an
+/// equivalent ExtractSliceOp that bufferizes inplace.
+static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+ const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
+ LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
+ << '\n');
+ bool foundOp = false;
+ aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
+ auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
+ if (extractSliceOp &&
+ areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
+ insertSliceOp) &&
+ getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
+ LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
+ foundOp = true;
+ }
+ });
+
+ if (!foundOp)
+ LDBG("\tnot equivalent\n");
+
+ return foundOp;
+}
+
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
@@ -3345,8 +3380,8 @@ struct InsertSliceOpInterface
// cloned and the clone needs to be updated.
auto inPlace = getInPlace(insertSliceOp->getResult(0));
// TODO: Is this necessary?
- if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- insertSliceOp) ||
+ if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
+ insertSliceOp) ||
inPlace != InPlaceSpec::True) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");
More information about the Mlir-commits
mailing list