[Mlir-commits] [mlir] cd94f18 - [mlir] Pass AnalysisManager as optional parameter to analysis ctor, so it can request any other analysis as dependency

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 20 09:19:20 PDT 2021


Author: Butygin
Date: 2021-04-20T19:18:36+03:00
New Revision: cd94f18ec1ba5b5c0e19e9c7506b1de86651354a

URL: https://github.com/llvm/llvm-project/commit/cd94f18ec1ba5b5c0e19e9c7506b1de86651354a
DIFF: https://github.com/llvm/llvm-project/commit/cd94f18ec1ba5b5c0e19e9c7506b1de86651354a.diff

LOG: [mlir] Pass AnalysisManager as optional parameter to analysis ctor, so it can request any other analysis as dependency

Differential Revision: https://reviews.llvm.org/D100274

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/AnalysisManager.h
    mlir/unittests/Pass/AnalysisManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index e2568bd1395c2..da91da30e746d 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -159,7 +159,10 @@ not passes but free-standing classes that are computed lazily on-demand and
 cached to avoid unnecessary recomputation. An analysis in MLIR must adhere to
 the following:
 
-*   Provide a valid constructor taking an `Operation*`.
+*   Provide a valid constructor taking either an `Operation*`, or an
+    `Operation*` and an `AnalysisManager &`.
+    *   The provided `AnalysisManager &` should be used to query any necessary
+        analysis dependencies.
 *   Must not modify the given operation.
 
 An analysis may provide additional hooks to control various behavior:
@@ -169,7 +172,9 @@ An analysis may provide additional hooks to control various behavior:
 Given a preserved analysis set, the analysis returns true if it should truly be
 invalidated. This allows for more fine-tuned invalidation in cases where an
 analysis wasn't explicitly marked preserved, but may be preserved (or
-invalidated) based upon other properties such as analyses sets.
+invalidated) based upon other properties such as analysis sets. If the analysis
+uses any other analysis as a dependency, it must also check whether the
+dependency was invalidated.
 
 ### Querying Analyses
 
@@ -200,6 +205,20 @@ struct MyOperationAnalysis {
   MyOperationAnalysis(Operation *op);
 };
 
+struct MyOperationAnalysisWithDependency {
+  MyOperationAnalysisWithDependency(Operation *op, AnalysisManager &am) {
+    // Request other analysis as dependency
+    MyOperationAnalysis &otherAnalysis = am.getAnalysis<MyOperationAnalysis>();
+    ...
+  }
+
+  bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    // Check if analysis or its dependency were invalidated
+    return !pa.isPreserved<MyOperationAnalysisWithDependency>() ||
+           !pa.isPreserved<MyOperationAnalysis>();
+  }
+};
+
 void MyOperationPass::runOnOperation() {
   // Query MyOperationAnalysis for the current operation.
   MyOperationAnalysis &myAnalysis = getAnalysis<MyOperationAnalysis>();
@@ -899,6 +918,10 @@ the PassManager that observe various events:
         executed, `runAfterPass` will *not* be.
 *   `runBeforeAnalysis`
     *   This callback is run just before an analysis is computed.
+    *   If the analysis requests another analysis as a dependency, the
+        `runBeforeAnalysis`/`runAfterAnalysis` pair for the dependency may be
+        invoked nested within the current `runBeforeAnalysis`/`runAfterAnalysis`
+        pair.
 *   `runAfterAnalysis`
     *   This callback is run right after an analysis is computed.
 

diff  --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 5da0c95d78dc7..21318b0097a49 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -13,10 +13,13 @@
 #include "mlir/Pass/PassInstrumentation.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/TypeName.h"
 
 namespace mlir {
+class AnalysisManager;
+
 //===----------------------------------------------------------------------===//
 // Analysis Preservation and Concept Modeling
 //===----------------------------------------------------------------------===//
@@ -59,6 +62,16 @@ class PreservedAnalyses {
   bool isPreserved(TypeID id) const { return preservedIDs.count(id); }
 
 private:
+  /// Remove the analysis from preserved set.
+  template <typename AnalysisT>
+  void unpreserve() {
+    preservedIDs.erase(TypeID::get<AnalysisT>());
+  }
+
+  /// AnalysisModel needs access to unpreserve().
+  template <typename>
+  friend struct AnalysisModel;
+
   /// The set of analyses that are known to be preserved.
   SmallPtrSet<TypeID, 2> preservedIDs;
 };
@@ -91,8 +104,9 @@ struct AnalysisConcept {
   /// set, returns true if it should truly be invalidated. This allows for more
   /// fine-tuned invalidation in cases where an analysis wasn't explicitly
   /// marked preserved, but may be preserved(or invalidated) based upon other
-  /// properties such as analyses sets.
-  virtual bool isInvalidated(const PreservedAnalyses &pa) = 0;
+  /// properties such as analysis sets. Invalidated analyses must also be
+  /// removed from pa.
+  virtual bool invalidate(PreservedAnalyses &pa) = 0;
 };
 
 /// A derived analysis model used to hold a specific analysis object.
@@ -101,9 +115,13 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
   explicit AnalysisModel(Args &&...args)
       : analysis(std::forward<Args>(args)...) {}
 
-  /// A hook used to query analyses for invalidation.
-  bool isInvalidated(const PreservedAnalyses &pa) final {
-    return analysis_impl::isInvalidated(analysis, pa);
+  /// A hook used to query analyses for invalidation. Removes invalidated
+  /// analyses from pa.
+  bool invalidate(PreservedAnalyses &pa) final {
+    bool result = analysis_impl::isInvalidated(analysis, pa);
+    if (result)
+      pa.unpreserve<AnalysisT>();
+    return result;
   }
 
   /// The actual analysis object.
@@ -114,7 +132,7 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
 /// computation, caching, and invalidation of analyses takes place here.
 class AnalysisMap {
   /// A mapping between an analysis id and an existing analysis instance.
-  using ConceptMap = DenseMap<TypeID, std::unique_ptr<AnalysisConcept>>;
+  using ConceptMap = llvm::MapVector<TypeID, std::unique_ptr<AnalysisConcept>>;
 
   /// Utility to return the name of the given analysis class.
   template <typename AnalysisT> static StringRef getAnalysisName() {
@@ -129,17 +147,19 @@ class AnalysisMap {
 
   /// Get an analysis for the current IR unit, computing it if necessary.
   template <typename AnalysisT>
-  AnalysisT &getAnalysis(PassInstrumentor *pi) {
-    return getAnalysisImpl<AnalysisT, Operation *>(pi, ir);
+  AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+    return getAnalysisImpl<AnalysisT, Operation *>(pi, ir, am);
   }
 
   /// Get an analysis for the current IR unit assuming it's of specific derived
   /// operation type.
   template <typename AnalysisT, typename OpT>
-  typename std::enable_if<std::is_constructible<AnalysisT, OpT>::value,
-                          AnalysisT &>::type
-  getAnalysis(PassInstrumentor *pi) {
-    return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir));
+  std::enable_if_t<
+      std::is_constructible<AnalysisT, OpT>::value ||
+          std::is_constructible<AnalysisT, OpT, AnalysisManager &>::value,
+      AnalysisT &>
+  getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+    return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir), am);
   }
 
   /// Get a cached analysis instance if one exists, otherwise return null.
@@ -160,30 +180,31 @@ class AnalysisMap {
   /// Invalidate any cached analyses based upon the given set of preserved
   /// analyses.
   void invalidate(const PreservedAnalyses &pa) {
+    PreservedAnalyses paCopy(pa);
     // Remove any analyses that were invalidated.
-    for (auto it = analyses.begin(), e = analyses.end(); it != e;) {
-      auto curIt = it++;
-      if (curIt->second->isInvalidated(pa))
-        analyses.erase(curIt);
-    }
+    // As we are using MapVector, order of insertion is preserved and
+    // dependencies always go before users, so we need only one iteration.
+    analyses.remove_if(
+        [&](auto &val) { return val.second->invalidate(paCopy); });
   }
 
 private:
   template <typename AnalysisT, typename OpT>
-  AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) {
+  AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op,
+                             AnalysisManager &am) {
     TypeID id = TypeID::get<AnalysisT>();
 
-    typename ConceptMap::iterator it;
-    bool wasInserted;
-    std::tie(it, wasInserted) = analyses.try_emplace(id);
-
+    auto it = analyses.find(id);
     // If we don't have a cached analysis for this operation, compute it
     // directly and add it to the cache.
-    if (wasInserted) {
+    if (analyses.end() == it) {
       if (pi)
         pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
 
-      it->second = std::make_unique<AnalysisModel<AnalysisT>>(op);
+      bool wasInserted;
+      std::tie(it, wasInserted) =
+          analyses.insert({id, constructAnalysis<AnalysisT>(am, op)});
+      assert(wasInserted);
 
       if (pi)
         pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
@@ -191,6 +212,22 @@ class AnalysisMap {
     return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
   }
 
+  /// Construct analysis via the (OpT, AnalysisManager &) constructor.
+  template <typename AnalysisT, typename OpT,
+            std::enable_if_t<std::is_constructible<
+                AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+  static auto constructAnalysis(AnalysisManager &am, OpT op) {
+    return std::make_unique<AnalysisModel<AnalysisT>>(op, am);
+  }
+
+  /// Construct analysis via the single-argument (OpT) constructor.
+  template <typename AnalysisT, typename OpT,
+            std::enable_if_t<!std::is_constructible<
+                AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+  static auto constructAnalysis(AnalysisManager &, OpT op) {
+    return std::make_unique<AnalysisModel<AnalysisT>>(op);
+  }
+
   Operation *ir;
   ConceptMap analyses;
 };
@@ -273,14 +310,15 @@ class AnalysisManager {
 
   /// Query for the given analysis for the current operation.
   template <typename AnalysisT> AnalysisT &getAnalysis() {
-    return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
+    return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor(), *this);
   }
 
   /// Query for the given analysis for the current operation of a specific
   /// derived operation type.
   template <typename AnalysisT, typename OpT>
   AnalysisT &getAnalysis() {
-    return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor());
+    return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor(),
+                                                      *this);
   }
 
   /// Query for a cached entry of the given analysis on the current operation.

diff  --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index 79ada798c4838..a5dcd18c82446 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -159,4 +159,91 @@ TEST(AnalysisManagerTest, OpSpecificAnalysis) {
   EXPECT_TRUE(am.getCachedAnalysis<OpSpecificAnalysis>().hasValue());
 }
 
+struct AnalysisWithDependency {
+  AnalysisWithDependency(Operation *, AnalysisManager &am) {
+    am.getAnalysis<MyAnalysis>();
+  }
+
+  bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    return !pa.isPreserved<AnalysisWithDependency>() ||
+           !pa.isPreserved<MyAnalysis>();
+  }
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  am.getAnalysis<AnalysisWithDependency>();
+  EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+
+  detail::PreservedAnalyses pa;
+  pa.preserve<AnalysisWithDependency>();
+  am.invalidate(pa);
+
+  EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_FALSE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+}
+
+struct AnalysisWithNestedDependency {
+  AnalysisWithNestedDependency(Operation *, AnalysisManager &am) {
+    am.getAnalysis<AnalysisWithDependency>();
+  }
+
+  bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    return !pa.isPreserved<AnalysisWithNestedDependency>() ||
+           !pa.isPreserved<AnalysisWithDependency>();
+  }
+};
+
+TEST(AnalysisManagerTest, NestedDependentAnalysis) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  am.getAnalysis<AnalysisWithNestedDependency>();
+  EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithNestedDependency>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+
+  detail::PreservedAnalyses pa;
+  pa.preserve<AnalysisWithDependency>();
+  pa.preserve<AnalysisWithNestedDependency>();
+  am.invalidate(pa);
+
+  EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithNestedDependency>().hasValue());
+  EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_FALSE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+}
+
+struct AnalysisWith2Ctors {
+  AnalysisWith2Ctors(Operation *) { ctor1called = true; }
+
+  AnalysisWith2Ctors(Operation *, AnalysisManager &) { ctor2called = true; }
+
+  bool ctor1called = false;
+  bool ctor2called = false;
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis2Ctors) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  auto &an = am.getAnalysis<AnalysisWith2Ctors>();
+  EXPECT_FALSE(an.ctor1called);
+  EXPECT_TRUE(an.ctor2called);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list