[Mlir-commits] [mlir] 2b5a020 - [mlir][bufferization][NFC] Cache definitions of read tensors
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 9 00:27:59 PST 2023
Author: Matthias Springer
Date: 2023-02-09T09:27:39+01:00
New Revision: 2b5a020d3e3ca2a887218776bc6e7bd930a656fa
URL: https://github.com/llvm/llvm-project/commit/2b5a020d3e3ca2a887218776bc6e7bd930a656fa
DIFF: https://github.com/llvm/llvm-project/commit/2b5a020d3e3ca2a887218776bc6e7bd930a656fa.diff
LOG: [mlir][bufferization][NFC] Cache definitions of read tensors
This is to avoid unnecessary traversals of the IR.
Differential Revision: https://reviews.llvm.org/D143408
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index f1f43e52f11dd..8a7d8f0abb5b8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -122,6 +122,13 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
+ /// Find the definitions of the given tensor value or retrieve them from the
+ /// cache.
+ const SetVector<Value> &findDefinitionsCached(Value value);
+
+ /// Reset cached data structures.
+ void resetCache();
+
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2);
@@ -226,6 +233,9 @@ class OneShotAnalysisState : public AnalysisState {
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
+ /// Cache definitions of tensor values.
+ DenseMap<Value, SetVector<Value>> cachedDefinitions;
+
/// Set of all OpResults that were decided to bufferize in-place.
llvm::DenseSet<OpOperand *> inplaceBufferized;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 44986a60f0693..d964be91f9626 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace mlir {
namespace bufferization {
class AnalysisState;
struct BufferizationStatistics;
+class OneShotAnalysisState;
struct OneShotBufferizationOptions;
/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
@@ -36,7 +37,7 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
/// following the aliasing OpOperand, that eventually ends at a single
/// tensor::EmptyOp.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- bufferization::AnalysisState &state,
+ OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc,
RewriteFn rewriteFunc);
@@ -44,7 +45,7 @@ LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 76dfa2079dbd7..1579cfd04c79b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -105,7 +105,7 @@ findValidInsertionPoint(Operation *emptyTensorOp,
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single tensor::EmptyOp.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, AnalysisState &state,
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
OpBuilder::InsertionGuard g(rewriter);
@@ -153,6 +153,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// Replace the tensor::EmptyOp.
rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
+ state.resetCache();
}
// Advance to the next operation.
@@ -189,7 +190,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
/// tensor::EmptyOp.
template <typename OpTy>
static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
return eliminateEmptyTensors(
rewriter, op, state,
/*anchorMatchFunc=*/
@@ -224,7 +225,7 @@ static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
LogicalResult
mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
tensor::InsertSliceOp>(rewriter, op, state)))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 8a7d660f488f9..02ef3a6496f72 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -222,7 +222,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitions(opResult).empty())
+ if (findDefinitionsCached(opResult).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -473,7 +473,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
- SetVector<Value> definitions = state.findDefinitions(uRead->get());
+ const SetVector<Value> &definitions =
+ state.findDefinitionsCached(uRead->get());
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -769,6 +770,19 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//
+// Find the values that define the contents of the given value: the values
+// found by walking the reverse use-def chain of `value` for which
+// `bufferizesToMemoryWrite` holds. Results are memoized in `cachedDefinitions`
+// until `resetCache()` is called.
+//
+// Note: The returned reference points into a DenseMap. It is invalidated by
+// `resetCache()` and by any later call to this function that inserts a new
+// cache entry, so callers must not hold on to it across such calls.
+const llvm::SetVector<Value> &
+OneShotAnalysisState::findDefinitionsCached(Value value) {
+  // Cache hit: a single hash lookup.
+  auto it = cachedDefinitions.find(value);
+  if (it != cachedDefinitions.end())
+    return it->second;
+
+  // Cache miss: traverse the IR once, then insert the result. The traversal
+  // runs before the map is touched, so no default-constructed entry is
+  // created and only one more hash lookup is needed for the insertion.
+  llvm::SetVector<Value> definitions = findValueInReverseUseDefChain(
+      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+      /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+  return cachedDefinitions.try_emplace(value, std::move(definitions))
+      .first->second;
+}
+
+/// Drop all memoized definition sets. Later `findDefinitionsCached` queries
+/// recompute them from the IR instead of returning stale results.
+void OneShotAnalysisState::resetCache() {
+  cachedDefinitions.clear();
+}
+
/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
More information about the Mlir-commits
mailing list