[Mlir-commits] [mlir] 2b5a020 - [mlir][bufferization][NFC] Cache definitions of read tensors

Matthias Springer llvmlistbot at llvm.org
Thu Feb 9 00:27:59 PST 2023


Author: Matthias Springer
Date: 2023-02-09T09:27:39+01:00
New Revision: 2b5a020d3e3ca2a887218776bc6e7bd930a656fa

URL: https://github.com/llvm/llvm-project/commit/2b5a020d3e3ca2a887218776bc6e7bd930a656fa
DIFF: https://github.com/llvm/llvm-project/commit/2b5a020d3e3ca2a887218776bc6e7bd930a656fa.diff

LOG: [mlir][bufferization][NFC] Cache definitions of read tensors

This is to avoid unnecessary traversals of the IR.

Differential Revision: https://reviews.llvm.org/D143408

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
    mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index f1f43e52f11dd..8a7d8f0abb5b8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -122,6 +122,13 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
+  /// Find the definitions of the given tensor value or retrieve them from the
+  /// cache. The reference is invalidated when the cache is modified or reset.
+  const SetVector<Value> &findDefinitionsCached(Value value);
+
+  /// Reset cached data structures, i.e., drop all cached tensor definitions.
+  void resetCache();
+
   /// Union the alias sets of `v1` and `v2`.
   void unionAliasSets(Value v1, Value v2);
 
@@ -226,6 +233,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
+  /// Cache definitions of tensor values.
+  DenseMap<Value, SetVector<Value>> cachedDefinitions;
+
   /// Set of all OpResults that were decided to bufferize in-place.
   llvm::DenseSet<OpOperand *> inplaceBufferized;
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 44986a60f0693..d964be91f9626 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace mlir {
 namespace bufferization {
 class AnalysisState;
 struct BufferizationStatistics;
+class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
 /// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
@@ -36,7 +37,7 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
 ///   following the aliasing  OpOperand, that eventually ends at a single
 ///   tensor::EmptyOp.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    bufferization::AnalysisState &state,
+                                    OneShotAnalysisState &state,
                                     AnchorMatchFn anchorMatchFunc,
                                     RewriteFn rewriteFunc);
 
@@ -44,7 +45,7 @@ LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
 LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 76dfa2079dbd7..1579cfd04c79b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -105,7 +105,7 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single tensor::EmptyOp.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state,
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
   OpBuilder::InsertionGuard g(rewriter);
 
@@ -153,6 +153,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
 
       // Replace the tensor::EmptyOp.
       rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
+      state.resetCache();
     }
 
     // Advance to the next operation.
@@ -189,7 +190,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
 ///   tensor::EmptyOp.
 template <typename OpTy>
 static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   return eliminateEmptyTensors(
       rewriter, op, state,
       /*anchorMatchFunc=*/
@@ -224,7 +225,7 @@ static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
 
 LogicalResult
 mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
              tensor::InsertSliceOp>(rewriter, op, state)))
     return failure();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 8a7d660f488f9..02ef3a6496f72 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -222,7 +222,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitions(opResult).empty())
+      if (findDefinitionsCached(opResult).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -473,7 +473,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
     // In the above example, if uRead is the OpOperand of reading_op, the
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-    SetVector<Value> definitions = state.findDefinitions(uRead->get());
+    const SetVector<Value> &definitions =
+        state.findDefinitionsCached(uRead->get());
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()
@@ -769,6 +770,19 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
+// Find the values that define the contents of `value`, memoized per value.
+const llvm::SetVector<Value> &
+OneShotAnalysisState::findDefinitionsCached(Value value) {
+  auto it = cachedDefinitions.find(value); // Single hash lookup on a hit.
+  if (it == cachedDefinitions.end()) // Miss: traverse once, store the result.
+    it = cachedDefinitions.try_emplace(value, findValueInReverseUseDefChain(
+        value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+        /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false)).first;
+  return it->second; // Invalidated by later cache insertions or resetCache().
+}
+
+// Drop every cached definition set; later queries recompute from the IR.
+void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); }
+
 /// Determine if `operand` can be bufferized in-place.
 static LogicalResult
 bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,


        


More information about the Mlir-commits mailing list