[Mlir-commits] [mlir] 9312b4f - [mlir][bufferization] Cache enclosing repetitive region
Martin Erhart
llvmlistbot at llvm.org
Sat Jul 8 02:33:35 PDT 2023
Author: Martin Erhart
Date: 2023-07-08T09:30:41Z
New Revision: 9312b4f90fcd9e1bf0186b66912c3b83c2d35f51
URL: https://github.com/llvm/llvm-project/commit/9312b4f90fcd9e1bf0186b66912c3b83c2d35f51
DIFF: https://github.com/llvm/llvm-project/commit/9312b4f90fcd9e1bf0186b66912c3b83c2d35f51.diff
LOG: [mlir][bufferization] Cache enclosing repetitive region
The `getEnclosingRepetitiveRegion` functions walk the ancestor regions every time, which can be expensive, especially when there are multiple regions in between. This commit adds a cache to the bufferization analysis to remember the result of the walk.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D154710
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 45d705c444a7e8..d1faaf56c6afc4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -12,6 +12,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
#include "llvm/ADT/SetVector.h"
#include <optional>
@@ -546,6 +547,21 @@ class AnalysisState {
TypeID getType() const { return type; }
+ /// Return the closest enclosing repetitive region around the given op.
+ Region *getEnclosingRepetitiveRegion(Operation *op,
+ const BufferizationOptions &options);
+
+ /// Return the closest enclosing repetitive region around the place where the
+ /// given value is defined.
+ Region *getEnclosingRepetitiveRegion(Value value,
+ const BufferizationOptions &options);
+
+ /// Return the closest enclosing repetitive region around the given block.
+ Region *getEnclosingRepetitiveRegion(Block *block,
+ const BufferizationOptions &options);
+
+ virtual void resetCache();
+
protected:
AnalysisState(const BufferizationOptions &options, TypeID type);
@@ -555,6 +571,10 @@ class AnalysisState {
/// The type of analysis.
TypeID type;
+
+ /// Cache containing closest ancestor repetitive Region.
+ DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *>
+ enclosingRepetitiveRegionCache;
};
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
@@ -652,19 +672,6 @@ getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
/// owner of the block. In case of an OpResult that is the defining op.
Operation *getOwnerOfValue(Value value);
-/// Return the closest enclosing repetitive region around the given op.
-Region *getEnclosingRepetitiveRegion(Operation *op,
- const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the place where the
-/// given value is defined.
-Region *getEnclosingRepetitiveRegion(Value value,
- const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the given block.
-Region *getEnclosingRepetitiveRegion(Block *block,
- const BufferizationOptions &options);
-
/// Assuming that the given region is repetitive, find the next enclosing
/// repetitive region.
Region *getNextEnclosingRepetitiveRegion(Region *region,
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 4fd3da1548c410..585c7ca92c7189 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -130,7 +130,7 @@ class OneShotAnalysisState : public AnalysisState {
const SetVector<Value> &findDefinitionsCached(Value value);
/// Reset cached data structures.
- void resetCache();
+ void resetCache() override;
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 9a6f85a625003f..d9c334983ad814 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -50,36 +50,64 @@ static bool isRepetitiveRegion(Region *region,
return false;
}
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
Operation *op, const BufferizationOptions &options) {
if (!op->getBlock())
return nullptr;
- return getEnclosingRepetitiveRegion(op->getBlock(), options);
+ if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
+ iter != enclosingRepetitiveRegionCache.end())
+ return iter->second;
+ return enclosingRepetitiveRegionCache[op] =
+ getEnclosingRepetitiveRegion(op->getBlock(), options);
}
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
Value value, const BufferizationOptions &options) {
+ if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
+ iter != enclosingRepetitiveRegionCache.end())
+ return iter->second;
+
Region *region = value.getParentRegion();
+ // Collect all visited regions since we only know the repetitive region we
+ // want to map it to later on
+ SmallVector<Region *> visitedRegions;
while (region) {
+ visitedRegions.push_back(region);
if (isRepetitiveRegion(region, options))
- return region;
+ break;
region = region->getParentRegion();
}
- return nullptr;
+ enclosingRepetitiveRegionCache[value] = region;
+ for (Region *r : visitedRegions)
+ enclosingRepetitiveRegionCache[r] = region;
+ return region;
}
-Region *bufferization::getEnclosingRepetitiveRegion(
+Region *AnalysisState::getEnclosingRepetitiveRegion(
Block *block, const BufferizationOptions &options) {
+ if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
+ iter != enclosingRepetitiveRegionCache.end())
+ return iter->second;
+
Region *region = block->getParent();
Operation *op = nullptr;
+ // Collect all visited regions since we only know the repetitive region we
+ // want to map it to later on
+ SmallVector<Region *> visitedRegions;
do {
op = region->getParentOp();
if (isRepetitiveRegion(region, options))
- return region;
+ break;
} while ((region = op->getParentRegion()));
- return nullptr;
+
+ enclosingRepetitiveRegionCache[block] = region;
+ for (Region *r : visitedRegions)
+ enclosingRepetitiveRegionCache[r] = region;
+ return region;
}
+void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }
+
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 75295862bd6fa8..330c16eb23d79e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -383,13 +383,14 @@ static bool happensBefore(Operation *a, Operation *b,
/// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
/// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
///
-bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
- const SetVector<Value> &definitions,
- const AnalysisState &state) {
+static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
+ const SetVector<Value> &definitions,
+ AnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
for (Value def : definitions) {
- Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options);
- Region *rDef = getEnclosingRepetitiveRegion(def, options);
+ Region *rRead =
+ state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
+ Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
// READ and DEF are in the same repetitive region. `happensBefore` can be
// used to rule out RaW conflicts due to op ordering.
@@ -782,7 +783,10 @@ OneShotAnalysisState::findDefinitionsCached(Value value) {
return cachedDefinitions[value];
}
-void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); }
+void OneShotAnalysisState::resetCache() {
+ AnalysisState::resetCache();
+ cachedDefinitions.clear();
+}
/// Determine if `operand` can be bufferized in-place.
static LogicalResult
More information about the Mlir-commits
mailing list