[Mlir-commits] [mlir] 1840d18 - [mlir][bufferization][NFC] Rename: "last-write" -> "definition"
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 30 00:52:29 PST 2023
Author: Matthias Springer
Date: 2023-01-30T09:51:53+01:00
New Revision: 1840d18a10cb831e3379438bcdd9eef48ed6a39e
URL: https://github.com/llvm/llvm-project/commit/1840d18a10cb831e3379438bcdd9eef48ed6a39e
DIFF: https://github.com/llvm/llvm-project/commit/1840d18a10cb831e3379438bcdd9eef48ed6a39e.diff
LOG: [mlir][bufferization][NFC] Rename: "last-write" -> "definition"
The previous lingo was confusing. There are no writes on tensors. There are only definitions.
Also some minor cleanup and better documentation.
Differential Revision: https://reviews.llvm.org/D141790
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ccd2711843156..061f2809b6d12 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -355,7 +355,8 @@ class AnalysisState {
/// traversed any further.
///
/// When reaching the end of a chain (BlockArgument or Value without aliasing
- /// OpOperands), also return the last Value of that chain.
+ /// OpOperands), also return the last Value of that chain if
+ /// `alwaysIncludeLeaves` is set.
///
/// Example:
///
@@ -374,20 +375,41 @@ class AnalysisState {
/// { 2, 7, 8, 5 }
///
/// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
- SetVector<Value>
- findValueInReverseUseDefChain(Value value,
- llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly = false) const;
-
- /// Find the Values of the last preceding write of a given Value.
+ SetVector<Value> findValueInReverseUseDefChain(
+ Value value, llvm::function_ref<bool(Value)> condition,
+ bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const;
+
+ /// Find the values that may define the contents of the given value at
+ /// runtime. A block argument is always a definition. An OpResult is a
+ /// definition if it bufferizes to a memory write. If it does not bufferize to
+ /// a memory write but has aliasing operands, we continue the lookup on these
+ /// values.
+ ///
+ /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ /// findDefinitions(%r) = {%r} because %r bufferizes to a memory write.
+ ///
+ /// Example: %r = tensor.empty() : tensor<10xf32>
+ /// findDefinitions(%r) = {} because tensor.empty does not define the
+ /// contents of its result (i.e., it does not bufferize to a memory write)
+ /// and it has no aliasing OpOperands.
+ ///
+ /// Example:
+ /// %a = arith.constant ... : tensor<10xf32>
+ /// %b1 = tensor.insert %f into %t : tensor<50xf32>
+ /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> to tensor<10xf32>
+ /// %r = arith.select %cond, %a, %b2 : tensor<10xf32>
+ /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues
+ /// in the operands) because their defining ops do not define the contents of
+ /// the tensor.
///
- /// Note: Unknown ops are handled conservatively and assumed to be writes.
- /// Furthermore, BlockArguments are also assumed to be writes. There is no
- /// analysis across block boundaries.
+ /// Note: OpResults of unknown ops are handled conservatively and assumed to
+ /// be definitions.
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
- /// is returned regardless of whether it is a memory write or not.
- SetVector<Value> findLastPrecedingWrite(Value value) const;
+ /// is included regardless of whether it is a definition or not unless
+ /// `alwaysIncludeLeaves` is unset.
+ SetVector<Value> findDefinitions(Value value,
+ bool alwaysIncludeLeaves = true) const;
/// Return `true` if the given OpOperand has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8abb9c35d2682..c24653d18abaa 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -444,7 +444,7 @@ bool AnalysisState::isValueRead(Value value) const {
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly) const {
+ bool followEquivalentOnly, bool alwaysIncludeLeaves) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -469,7 +469,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
(followEquivalentOnly &&
bufferizableOp.bufferRelation(opResult, *this) !=
BufferRelation::Equivalent)) {
- result.insert(value);
+ if (alwaysIncludeLeaves)
+ result.insert(value);
continue;
}
@@ -480,11 +481,12 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
return result;
}
-// Find the Values of the last preceding write of a given Value.
+// Find the values that define the contents of the given value.
llvm::SetVector<Value>
-AnalysisState::findLastPrecedingWrite(Value value) const {
+AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const {
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); });
+ value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+ /*followEquivalentOnly=*/false, alwaysIncludeLeaves);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 2e48067140955..8570352b52b4b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -270,16 +270,9 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
if (!opResult.getType().isa<TensorType>())
continue;
- // If there is no preceding memory write, the tensor contents are
+ // If there is no preceding definition, the tensor contents are
// undefined.
- // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
- // use-def chain, it returns that value, regardless of whether it is a
- // memory write or not.
- SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
- bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
- return this->bufferizesToMemoryWrite(lastWrite);
- });
- if (isUndefined)
+ if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -471,7 +464,7 @@ bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
- Value lastWrite) {
+ Value definition) {
static uint64_t counter = 0;
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -489,16 +482,15 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
readingOp->setAttr(readAttr, b.getUnitAttr());
- if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
- std::string lastWriteAttr = id + "[LAST-WRITE: result " +
- std::to_string(opResult.getResultNumber()) +
- "]";
- opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ if (auto opResult = definition.dyn_cast<OpResult>()) {
+ std::string defAttr =
+ id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
+ opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
} else {
- auto bbArg = lastWrite.cast<BlockArgument>();
- std::string lastWriteAttr =
- id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
- bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ auto bbArg = definition.cast<BlockArgument>();
+ std::string defAttr =
+ id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
+ bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
}
}
@@ -507,8 +499,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
-/// the result of a write W1. But because of bufferization decisions, R actually
-/// reads another write W2.
+/// the result of a definition W1. But because of bufferization decisions, R
+/// actually reads another definition W2.
static bool hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
@@ -529,10 +521,10 @@ static bool hasReadAfterWriteInterference(
// %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
// %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
//
- // In the above example, if uRead is the OpOperand of reading_op, lastWrite
- // is %0. Note that operations that create an alias but do not write (such
- // as ExtractSliceOp) are skipped.
- SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
+ // In the above example, if uRead is the OpOperand of reading_op, the
+ // definition is %0. Note that operations that create an alias but do not
+ // bufferize to a memory write (such as ExtractSliceOp) are skipped.
+ SetVector<Value> definitions = state.findDefinitions(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
@@ -611,31 +603,30 @@ static bool hasReadAfterWriteInterference(
}
}
- // Check all possible last writes.
- for (Value lastWrite : lastWrites) {
- LLVM_DEBUG(llvm::dbgs() << " * lastWrite = " << lastWrite << "\n");
+ // Check all possible definitions.
+ for (Value definition : definitions) {
+ LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
- // No conflict if the conflicting write happens before the last
- // write.
- if (Operation *writingOp = lastWrite.getDefiningOp()) {
+ // No conflict if the conflicting write happens before the definition.
+ if (Operation *writingOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
// conflictingWritingOp happens before writingOp. No conflict.
LLVM_DEBUG(llvm::dbgs()
- << " no conflict: write happens before last write\n");
+ << " no conflict: write happens before definition\n");
continue;
}
// No conflict if conflictingWritingOp is contained in writingOp.
if (writingOp->isProperAncestor(conflictingWritingOp)) {
LLVM_DEBUG(
llvm::dbgs()
- << " no conflict: write is contained in last write\n");
+ << " no conflict: write is contained in definition\n");
continue;
}
} else {
- auto bbArg = lastWrite.cast<BlockArgument>();
+ auto bbArg = definition.cast<BlockArgument>();
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: last write is bbArg "
+ LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
"and write happens outside of block\n");
// conflictingWritingOp happens outside of the block. No
// conflict.
@@ -643,20 +634,20 @@ static bool hasReadAfterWriteInterference(
}
}
- // No conflict if the conflicting write and the last write are the same
+ // No conflict if the conflicting write and the definition are the same
// use.
SmallVector<OpResult> aliasingOpResult =
state.getAliasingOpResult(*uConflictingWrite);
- if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) {
+ if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == definition) {
LLVM_DEBUG(llvm::dbgs()
- << " no conflict: last write and write are same\n");
+ << " no conflict: definition and write are same\n");
continue;
}
// All requirements are met. Conflict found!
if (options.printConflicts)
- annotateConflict(uRead, uConflictingWrite, lastWrite);
+ annotateConflict(uRead, uConflictingWrite, definition);
LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
return true;
}
@@ -734,8 +725,8 @@ static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
-/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
-/// would no longer be reading the result of %1.
+/// TransferWriteOp overwrites the result of the first one. Therefore, the
+/// TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 5d0d44f3c53d6..2f578445de1dd 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -712,7 +712,7 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // lastWrite = %1
+ // definition = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
More information about the Mlir-commits
mailing list