[Mlir-commits] [mlir] cd94f18 - [mlir] Pass AnalysisManager as optional parameter to analysis ctor, so it can request any other analysis as dependency
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 20 09:19:20 PDT 2021
Author: Butygin
Date: 2021-04-20T19:18:36+03:00
New Revision: cd94f18ec1ba5b5c0e19e9c7506b1de86651354a
URL: https://github.com/llvm/llvm-project/commit/cd94f18ec1ba5b5c0e19e9c7506b1de86651354a
DIFF: https://github.com/llvm/llvm-project/commit/cd94f18ec1ba5b5c0e19e9c7506b1de86651354a.diff
LOG: [mlir] Pass AnalysisManager as optional parameter to analysis ctor, so it can request any other analysis as dependency
Differential Revision: https://reviews.llvm.org/D100274
Added:
Modified:
mlir/docs/PassManagement.md
mlir/include/mlir/Pass/AnalysisManager.h
mlir/unittests/Pass/AnalysisManagerTest.cpp
Removed:
################################################################################
diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index e2568bd1395c2..da91da30e746d 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -159,7 +159,10 @@ not passes but free-standing classes that are computed lazily on-demand and
cached to avoid unnecessary recomputation. An analysis in MLIR must adhere to
the following:
-* Provide a valid constructor taking an `Operation*`.
+* Provide a valid constructor taking either an `Operation*`, or an
+ `Operation*` and an `AnalysisManager &`.
+ * The provided `AnalysisManager &` should be used to query any necessary
+ analysis dependencies.
* Must not modify the given operation.
An analysis may provide additional hooks to control various behavior:
@@ -169,7 +172,9 @@ An analysis may provide additional hooks to control various behavior:
Given a preserved analysis set, the analysis returns true if it should truly be
invalidated. This allows for more fine-tuned invalidation in cases where an
analysis wasn't explicitly marked preserved, but may be preserved (or
-invalidated) based upon other properties such as analyses sets.
+invalidated) based upon other properties such as analysis sets. If the analysis
+uses any other analysis as a dependency, it must also check if the dependency
+was invalidated.
### Querying Analyses
@@ -200,6 +205,20 @@ struct MyOperationAnalysis {
MyOperationAnalysis(Operation *op);
};
+struct MyOperationAnalysisWithDependency {
+ MyOperationAnalysisWithDependency(Operation *op, AnalysisManager &am) {
+ // Request another analysis as a dependency via the provided manager.
+ MyOperationAnalysis &otherAnalysis = am.getAnalysis<MyOperationAnalysis>();
+ ...
+ }
+
+ bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+ // Invalidated unless both this analysis and its dependency are preserved.
+ return !pa.isPreserved<MyOperationAnalysisWithDependency>() ||
+ !pa.isPreserved<MyOperationAnalysis>();
+ }
+};
+
void MyOperationPass::runOnOperation() {
// Query MyOperationAnalysis for the current operation.
MyOperationAnalysis &myAnalysis = getAnalysis<MyOperationAnalysis>();
@@ -899,6 +918,10 @@ the PassManager that observe various events:
executed, `runAfterPass` will *not* be.
* `runBeforeAnalysis`
* This callback is run just before an analysis is computed.
+ * If the analysis requests another analysis as a dependency, the
+ `runBeforeAnalysis`/`runAfterAnalysis` pair for the dependency may be
+ nested inside the current `runBeforeAnalysis`/`runAfterAnalysis`
+ pair.
* `runAfterAnalysis`
* This callback is run right after an analysis is computed.
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 5da0c95d78dc7..21318b0097a49 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -13,10 +13,13 @@
#include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/TypeName.h"
namespace mlir {
+class AnalysisManager;
+
//===----------------------------------------------------------------------===//
// Analysis Preservation and Concept Modeling
//===----------------------------------------------------------------------===//
@@ -59,6 +62,16 @@ class PreservedAnalyses {
bool isPreserved(TypeID id) const { return preservedIDs.count(id); }
private:
+ /// Remove the analysis (by its TypeID) from the preserved set.
+ template <typename AnalysisT>
+ void unpreserve() {
+ preservedIDs.erase(TypeID::get<AnalysisT>());
+ }
+
+ /// AnalysisModel needs access to unpreserve() to drop invalidated analyses.
+ template <typename>
+ friend struct AnalysisModel;
+
/// The set of analyses that are known to be preserved.
SmallPtrSet<TypeID, 2> preservedIDs;
};
@@ -91,8 +104,9 @@ struct AnalysisConcept {
/// set, returns true if it should truly be invalidated. This allows for more
/// fine-tuned invalidation in cases where an analysis wasn't explicitly
/// marked preserved, but may be preserved(or invalidated) based upon other
- /// properties such as analyses sets.
- virtual bool isInvalidated(const PreservedAnalyses &pa) = 0;
+ /// properties such as analyses sets. Invalidated analyses must also be
+ /// removed from pa.
+ virtual bool invalidate(PreservedAnalyses &pa) = 0;
};
/// A derived analysis model used to hold a specific analysis object.
@@ -101,9 +115,13 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
explicit AnalysisModel(Args &&...args)
: analysis(std::forward<Args>(args)...) {}
- /// A hook used to query analyses for invalidation.
- bool isInvalidated(const PreservedAnalyses &pa) final {
- return analysis_impl::isInvalidated(analysis, pa);
+ /// A hook used to query this analysis for invalidation. If it is
+ /// invalidated, its own TypeID is also removed from pa.
+ bool invalidate(PreservedAnalyses &pa) final {
+ bool result = analysis_impl::isInvalidated(analysis, pa);
+ if (result)
+ pa.unpreserve<AnalysisT>();
+ return result;
+ }
/// The actual analysis object.
@@ -114,7 +132,7 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
/// computation, caching, and invalidation of analyses takes place here.
class AnalysisMap {
/// A mapping between an analysis id and an existing analysis instance.
- using ConceptMap = DenseMap<TypeID, std::unique_ptr<AnalysisConcept>>;
+ using ConceptMap = llvm::MapVector<TypeID, std::unique_ptr<AnalysisConcept>>;
/// Utility to return the name of the given analysis class.
template <typename AnalysisT> static StringRef getAnalysisName() {
@@ -129,17 +147,19 @@ class AnalysisMap {
/// Get an analysis for the current IR unit, computing it if necessary.
template <typename AnalysisT>
- AnalysisT &getAnalysis(PassInstrumentor *pi) {
- return getAnalysisImpl<AnalysisT, Operation *>(pi, ir);
+ AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+ return getAnalysisImpl<AnalysisT, Operation *>(pi, ir, am);
}
/// Get an analysis for the current IR unit assuming it's of specific derived
/// operation type.
template <typename AnalysisT, typename OpT>
- typename std::enable_if<std::is_constructible<AnalysisT, OpT>::value,
- AnalysisT &>::type
- getAnalysis(PassInstrumentor *pi) {
- return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir));
+ std::enable_if_t<
+ std::is_constructible<AnalysisT, OpT>::value ||
+ std::is_constructible<AnalysisT, OpT, AnalysisManager &>::value,
+ AnalysisT &>
+ getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+ return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir), am);
}
/// Get a cached analysis instance if one exists, otherwise return null.
@@ -160,30 +180,31 @@ class AnalysisMap {
/// Invalidate any cached analyses based upon the given set of preserved
/// analyses.
void invalidate(const PreservedAnalyses &pa) {
+ PreservedAnalyses paCopy(pa);
// Remove any analyses that were invalidated.
- for (auto it = analyses.begin(), e = analyses.end(); it != e;) {
- auto curIt = it++;
- if (curIt->second->isInvalidated(pa))
- analyses.erase(curIt);
- }
+ // MapVector preserves insertion order and dependencies are inserted before
+ // users, so a single pass sees each dependency's unpreserve() in paCopy.
+ analyses.remove_if(
+ [&](auto &val) { return val.second->invalidate(paCopy); });
}
private:
template <typename AnalysisT, typename OpT>
- AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) {
+ AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op,
+ AnalysisManager &am) {
TypeID id = TypeID::get<AnalysisT>();
- typename ConceptMap::iterator it;
- bool wasInserted;
- std::tie(it, wasInserted) = analyses.try_emplace(id);
-
+ auto it = analyses.find(id);
// If we don't have a cached analysis for this operation, compute it
// directly and add it to the cache.
- if (wasInserted) {
+ if (analyses.end() == it) {
if (pi)
pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
- it->second = std::make_unique<AnalysisModel<AnalysisT>>(op);
+ bool wasInserted;
+ std::tie(it, wasInserted) =
+ analyses.insert({id, constructAnalysis<AnalysisT>(am, op)});
+ assert(wasInserted);
if (pi)
pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
@@ -191,6 +212,22 @@ class AnalysisMap {
return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
}
+ /// Construct analysis using its two-argument constructor (OpT, AnalysisManager &).
+ template <typename AnalysisT, typename OpT,
+ std::enable_if_t<std::is_constructible<
+ AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+ static auto constructAnalysis(AnalysisManager &am, OpT op) {
+ return std::make_unique<AnalysisModel<AnalysisT>>(op, am);
+ }
+
+ /// Fallback: construct analysis using its single-argument constructor (OpT).
+ template <typename AnalysisT, typename OpT,
+ std::enable_if_t<!std::is_constructible<
+ AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+ static auto constructAnalysis(AnalysisManager &, OpT op) {
+ return std::make_unique<AnalysisModel<AnalysisT>>(op);
+ }
+
Operation *ir;
ConceptMap analyses;
};
@@ -273,14 +310,15 @@ class AnalysisManager {
/// Query for the given analysis for the current operation.
template <typename AnalysisT> AnalysisT &getAnalysis() {
- return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
+ return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor(), *this);
}
/// Query for the given analysis for the current operation of a specific
/// derived operation type.
template <typename AnalysisT, typename OpT>
AnalysisT &getAnalysis() {
- return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor());
+ return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor(),
+ *this);
}
/// Query for a cached entry of the given analysis on the current operation.
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index 79ada798c4838..a5dcd18c82446 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -159,4 +159,91 @@ TEST(AnalysisManagerTest, OpSpecificAnalysis) {
EXPECT_TRUE(am.getCachedAnalysis<OpSpecificAnalysis>().hasValue());
}
+struct AnalysisWithDependency {
+ AnalysisWithDependency(Operation *, AnalysisManager &am) {
+ am.getAnalysis<MyAnalysis>();
+ }
+
+ bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+ return !pa.isPreserved<AnalysisWithDependency>() ||
+ !pa.isPreserved<MyAnalysis>();
+ }
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis) {
+ MLIRContext context;
+
+ // Create a module.
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+ AnalysisManager am = mam;
+
+ am.getAnalysis<AnalysisWithDependency>();
+ EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+ EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+
+ detail::PreservedAnalyses pa;
+ pa.preserve<AnalysisWithDependency>();
+ am.invalidate(pa);
+
+ EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+ EXPECT_FALSE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+}
+
+struct AnalysisWithNestedDependency {
+ AnalysisWithNestedDependency(Operation *, AnalysisManager &am) {
+ am.getAnalysis<AnalysisWithDependency>();
+ }
+
+ bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+ return !pa.isPreserved<AnalysisWithNestedDependency>() ||
+ !pa.isPreserved<AnalysisWithDependency>();
+ }
+};
+
+TEST(AnalysisManagerTest, NestedDependentAnalysis) {
+ MLIRContext context;
+
+ // Create a module.
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+ AnalysisManager am = mam;
+
+ am.getAnalysis<AnalysisWithNestedDependency>();
+ EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithNestedDependency>().hasValue());
+ EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+ EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+
+ detail::PreservedAnalyses pa;
+ pa.preserve<AnalysisWithDependency>();
+ pa.preserve<AnalysisWithNestedDependency>();
+ am.invalidate(pa);
+
+ EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithNestedDependency>().hasValue());
+ EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+ EXPECT_FALSE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+}
+
+struct AnalysisWith2Ctors {
+ AnalysisWith2Ctors(Operation *) { ctor1called = true; }
+
+ AnalysisWith2Ctors(Operation *, AnalysisManager &) { ctor2called = true; }
+
+ bool ctor1called = false;
+ bool ctor2called = false;
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis2Ctors) {
+ MLIRContext context;
+
+ // Create a module.
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+ AnalysisManager am = mam;
+
+ auto &an = am.getAnalysis<AnalysisWith2Ctors>();
+ EXPECT_FALSE(an.ctor1called);
+ EXPECT_TRUE(an.ctor2called);
+}
+
} // end namespace
More information about the Mlir-commits
mailing list