[Mlir-commits] [mlir] 9312b4f - [mlir][bufferization] Cache enclosing repetitive region

Martin Erhart llvmlistbot at llvm.org
Sat Jul 8 02:33:35 PDT 2023


Author: Martin Erhart
Date: 2023-07-08T09:30:41Z
New Revision: 9312b4f90fcd9e1bf0186b66912c3b83c2d35f51

URL: https://github.com/llvm/llvm-project/commit/9312b4f90fcd9e1bf0186b66912c3b83c2d35f51
DIFF: https://github.com/llvm/llvm-project/commit/9312b4f90fcd9e1bf0186b66912c3b83c2d35f51.diff

LOG: [mlir][bufferization] Cache enclosing repetitive region

The `getEnclosingRepetitiveRegion` functions walk the ancestor regions every time they are called, which can be expensive, especially when there are many regions in between. This commit adds a cache to the bufferization analysis to remember the result of the walk.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D154710

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 45d705c444a7e8..d1faaf56c6afc4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
 #include "llvm/ADT/SetVector.h"
 #include <optional>
 
@@ -546,6 +547,29 @@ class AnalysisState {
 
   TypeID getType() const { return type; }
 
+  // The getEnclosingRepetitiveRegion overloads memoize their results in
+  // `enclosingRepetitiveRegionCache`. The cache key does not include
+  // `options`, so callers should consistently pass this state's options.
+  // Call `resetCache()` after modifying the IR to drop stale entries.
+
+  /// Return the closest enclosing repetitive region around the given op.
+  Region *getEnclosingRepetitiveRegion(Operation *op,
+                                       const BufferizationOptions &options);
+
+  /// Return the closest enclosing repetitive region around the place where the
+  /// given value is defined.
+  Region *getEnclosingRepetitiveRegion(Value value,
+                                       const BufferizationOptions &options);
+
+  /// Return the closest enclosing repetitive region around the given block.
+  Region *getEnclosingRepetitiveRegion(Block *block,
+                                       const BufferizationOptions &options);
+
+  /// Clear cached analysis data, including the enclosing repetitive region
+  /// cache. Subclasses with additional caches should override this and call
+  /// the base implementation.
+  virtual void resetCache();
+
 protected:
   AnalysisState(const BufferizationOptions &options, TypeID type);
 
@@ -555,6 +579,12 @@ class AnalysisState {
 
   /// The type of analysis.
   TypeID type;
+
+  /// Memoized results of getEnclosingRepetitiveRegion: maps a queried op,
+  /// block, or value (and regions visited while walking up) to its closest
+  /// enclosing repetitive region, or nullptr if there is none.
+  DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *>
+      enclosingRepetitiveRegionCache;
 };
 
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
@@ -652,19 +682,6 @@ getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
 /// owner of the block. In case of an OpResult that is the defining op.
 Operation *getOwnerOfValue(Value value);
 
-/// Return the closest enclosing repetitive region around the given op.
-Region *getEnclosingRepetitiveRegion(Operation *op,
-                                     const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the place where the
-/// given value is defined.
-Region *getEnclosingRepetitiveRegion(Value value,
-                                     const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the given block.
-Region *getEnclosingRepetitiveRegion(Block *block,
-                                     const BufferizationOptions &options);
-
 /// Assuming that the given region is repetitive, find the next enclosing
 /// repetitive region.
 Region *getNextEnclosingRepetitiveRegion(Region *region,

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 4fd3da1548c410..585c7ca92c7189 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -130,7 +130,7 @@ class OneShotAnalysisState : public AnalysisState {
   const SetVector<Value> &findDefinitionsCached(Value value);
 
   /// Reset cached data structures.
-  void resetCache();
+  void resetCache() override;
 
   /// Union the alias sets of `v1` and `v2`.
   void unionAliasSets(Value v1, Value v2);

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 9a6f85a625003f..d9c334983ad814 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -50,36 +50,68 @@ static bool isRepetitiveRegion(Region *region,
   return false;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
     Operation *op, const BufferizationOptions &options) {
   if (!op->getBlock())
     return nullptr;
-  return getEnclosingRepetitiveRegion(op->getBlock(), options);
+  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
+      iter != enclosingRepetitiveRegionCache.end())
+    return iter->second;
+  // Compute the result before touching the map: the recursive query may grow
+  // it, which would invalidate a reference obtained via operator[] first.
+  Region *region = getEnclosingRepetitiveRegion(op->getBlock(), options);
+  enclosingRepetitiveRegionCache[op] = region;
+  return region;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
     Value value, const BufferizationOptions &options) {
+  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
+      iter != enclosingRepetitiveRegionCache.end())
+    return iter->second;
+
   Region *region = value.getParentRegion();
+  // Collect all visited regions since we only know the repetitive region we
+  // want to map them to once the walk has finished.
+  SmallVector<Region *> visitedRegions;
   while (region) {
+    visitedRegions.push_back(region);
     if (isRepetitiveRegion(region, options))
-      return region;
+      break;
     region = region->getParentRegion();
   }
-  return nullptr;
+  enclosingRepetitiveRegionCache[value] = region;
+  for (Region *r : visitedRegions)
+    enclosingRepetitiveRegionCache[r] = region;
+  return region;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
     Block *block, const BufferizationOptions &options) {
+  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
+      iter != enclosingRepetitiveRegionCache.end())
+    return iter->second;
+
   Region *region = block->getParent();
   Operation *op = nullptr;
+  // Collect all visited regions since we only know the repetitive region we
+  // want to map them to once the walk has finished.
+  SmallVector<Region *> visitedRegions;
   do {
+    visitedRegions.push_back(region);
     op = region->getParentOp();
     if (isRepetitiveRegion(region, options))
-      return region;
+      break;
   } while ((region = op->getParentRegion()));
-  return nullptr;
+
+  enclosingRepetitiveRegionCache[block] = region;
+  for (Region *r : visitedRegions)
+    enclosingRepetitiveRegionCache[r] = region;
+  return region;
 }
 
+void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }
+
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 75295862bd6fa8..330c16eb23d79e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -383,13 +383,14 @@ static bool happensBefore(Operation *a, Operation *b,
 ///    regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
 ///    or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
 ///
-bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
-                       const SetVector<Value> &definitions,
-                       const AnalysisState &state) {
+static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
+                              const SetVector<Value> &definitions,
+                              AnalysisState &state) {
   const BufferizationOptions &options = state.getOptions();
   for (Value def : definitions) {
-    Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options);
-    Region *rDef = getEnclosingRepetitiveRegion(def, options);
+    Region *rRead =
+        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
+    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
 
     // READ and DEF are in the same repetitive region. `happensBefore` can be
     // used to rule out RaW conflicts due to op ordering.
@@ -782,7 +783,10 @@ OneShotAnalysisState::findDefinitionsCached(Value value) {
   return cachedDefinitions[value];
 }
 
-void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); }
+void OneShotAnalysisState::resetCache() {
+  AnalysisState::resetCache();
+  cachedDefinitions.clear();
+}
 
 /// Determine if `operand` can be bufferized in-place.
 static LogicalResult


        


More information about the Mlir-commits mailing list