[Mlir-commits] [mlir] 9a4b30c - [MLIR] Add support for defining and using Op specific analysis

Rahul Joshi llvmlistbot at llvm.org
Mon Aug 17 09:01:16 PDT 2020


Author: Rahul Joshi
Date: 2020-08-17T09:00:47-07:00
New Revision: 9a4b30cf84298887f0a7bf70b865493d767abdc9

URL: https://github.com/llvm/llvm-project/commit/9a4b30cf84298887f0a7bf70b865493d767abdc9
DIFF: https://github.com/llvm/llvm-project/commit/9a4b30cf84298887f0a7bf70b865493d767abdc9.diff

LOG: [MLIR] Add support for defining and using Op specific analysis

- Add variants of getAnalysis() and friends that operate on specific derived
  operation types.
- Add OperationPass::getAnalysis() to always call the base getAnalysis() with OpT.
- With this, an OperationPass can call getAnalysis<> using an analysis type that
  is generic (works on Operation *) or specific to the OpT for the pass. Anything
  else will fail to compile.
- Extend AnalysisManager unit test to test this, and add a new PassManager unit
  test to test this functionality in the context of an OperationPass.

Differential Revision: https://reviews.llvm.org/D84897

Added: 
    mlir/unittests/Pass/PassManagerTest.cpp

Modified: 
    mlir/include/mlir/Pass/AnalysisManager.h
    mlir/include/mlir/Pass/Pass.h
    mlir/unittests/Pass/AnalysisManagerTest.cpp
    mlir/unittests/Pass/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 4e9b3f20c929..37036e2298d0 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -128,25 +128,18 @@ class AnalysisMap {
   explicit AnalysisMap(Operation *ir) : ir(ir) {}
 
   /// Get an analysis for the current IR unit, computing it if necessary.
-  template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
-    TypeID id = TypeID::get<AnalysisT>();
-
-    typename ConceptMap::iterator it;
-    bool wasInserted;
-    std::tie(it, wasInserted) = analyses.try_emplace(id);
-
-    // If we don't have a cached analysis for this function, compute it directly
-    // and add it to the cache.
-    if (wasInserted) {
-      if (pi)
-        pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
-
-      it->second = std::make_unique<AnalysisModel<AnalysisT>>(ir);
+  template <typename AnalysisT>
+  AnalysisT &getAnalysis(PassInstrumentor *pi) {
+    return getAnalysisImpl<AnalysisT, Operation *>(pi, ir);
+  }
 
-      if (pi)
-        pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
-    }
-    return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
+  /// Get an analysis for the current IR unit, which must be of the derived
+  /// operation type OpT. Enabled only if AnalysisT is constructible from OpT.
+  template <typename AnalysisT, typename OpT>
+  typename std::enable_if<std::is_constructible<AnalysisT, OpT>::value,
+                          AnalysisT &>::type
+  getAnalysis(PassInstrumentor *pi) {
+    return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir));
   }
 
   /// Get a cached analysis instance if one exists, otherwise return null.
@@ -176,6 +169,28 @@ class AnalysisMap {
   }
 
 private:
+  template <typename AnalysisT, typename OpT>
+  AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) {
+    TypeID id = TypeID::get<AnalysisT>();
+
+    typename ConceptMap::iterator it;
+    bool wasInserted;
+    std::tie(it, wasInserted) = analyses.try_emplace(id);
+
+    // If we don't have a cached analysis for this operation, compute it
+    // directly from `op` and add it to the cache.
+    if (wasInserted) {
+      if (pi)
+        pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+
+      it->second = std::make_unique<AnalysisModel<AnalysisT>>(op);
+
+      if (pi)
+        pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+    }
+    return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
+  }
+
   Operation *ir;
   ConceptMap analyses;
 };
@@ -216,8 +231,8 @@ class AnalysisManager {
 public:
   using PreservedAnalyses = detail::PreservedAnalyses;
 
-  // Query for a cached analysis on the given parent operation. The analysis may
-  // not exist and if it does it may be out-of-date.
+  /// Query for a cached analysis on the given parent operation. The analysis
+  /// may not exist and if it does it may be out-of-date.
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>>
   getCachedParentAnalysis(Operation *parentOp) const {
@@ -230,12 +245,19 @@ class AnalysisManager {
     return None;
   }
 
-  // Query for the given analysis for the current operation.
+  /// Query for the given analysis for the current operation.
   template <typename AnalysisT> AnalysisT &getAnalysis() {
     return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
   }
 
-  // Query for a cached entry of the given analysis on the current operation.
+  /// Query for the given analysis for the current operation, which must be of
+  /// the specific derived operation type OpT.
+  template <typename AnalysisT, typename OpT>
+  AnalysisT &getAnalysis() {
+    return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor());
+  }
+
+  /// Query for a cached entry of the given analysis on the current operation.
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
     return impl->analyses.getCachedAnalysis<AnalysisT>();
@@ -246,6 +268,13 @@ class AnalysisManager {
     return slice(op).template getAnalysis<AnalysisT>();
   }
 
+  /// Query for an analysis of a child operation of a specific derived
+  /// operation type, constructing it if necessary.
+  template <typename AnalysisT, typename OpT>
+  AnalysisT &getChildAnalysis(OpT child) {
+    return slice(child).template getAnalysis<AnalysisT, OpT>();
+  }
+
   /// Query for a cached analysis of a child operation, or return null.
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>>

diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7c0f9bd958a1..8de31d944319 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -167,6 +167,13 @@ class Pass {
     return getAnalysisManager().getAnalysis<AnalysisT>();
   }
 
+  /// Query an analysis for the current IR unit of a specific derived operation
+  /// type.
+  template <typename AnalysisT, typename OpT>
+  AnalysisT &getAnalysis() {
+    return getAnalysisManager().getAnalysis<AnalysisT, OpT>();
+  }
+
   /// Query a cached instance of an analysis for the current ir unit if one
   /// exists.
   template <typename AnalysisT>
@@ -187,12 +194,14 @@ class Pass {
     getPassState().preservedAnalyses.preserve(id);
   }
 
-  /// Returns the analysis for the parent operation if it exists.
+  /// Returns the analysis for the given parent operation if it exists.
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>>
   getCachedParentAnalysis(Operation *parent) {
     return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(parent);
   }
+
+  /// Returns the analysis for the parent operation if it exists.
   template <typename AnalysisT>
   Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
     return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(
@@ -212,6 +221,13 @@ class Pass {
     return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
   }
 
+  /// Returns the analysis for the given child operation of a specific derived
+  /// operation type, or creates it if it doesn't exist.
+  template <typename AnalysisT, typename OpT>
+  AnalysisT &getChildAnalysis(OpT child) {
+    return getAnalysisManager().getChildAnalysis<AnalysisT, OpT>(child);
+  }
+
   /// Returns the current analysis manager.
   AnalysisManager getAnalysisManager() {
     return getPassState().analysisManager;
@@ -286,6 +302,13 @@ template <typename OpT = void> class OperationPass : public Pass {
 
   /// Return the current operation being transformed.
   OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
+
+  /// Query an analysis for the current operation of the specific derived
+  /// operation type.
+  template <typename AnalysisT>
+  AnalysisT &getAnalysis() {
+    return Pass::getAnalysis<AnalysisT, OpT>();
+  }
 };
 
 /// Pass to transform an operation.

diff  --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index a99df3911a5d..41a90649deef 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Pass/AnalysisManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -22,6 +24,9 @@ struct MyAnalysis {
 struct OtherAnalysis {
   OtherAnalysis(Operation *) {}
 };
+struct OpSpecificAnalysis {
+  OpSpecificAnalysis(ModuleOp) {}
+};
 
 TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
   MLIRContext context;
@@ -138,4 +143,18 @@ TEST(AnalysisManagerTest, CustomInvalidation) {
   am.invalidate(pa);
   EXPECT_TRUE(am.getCachedAnalysis<CustomInvalidatingAnalysis>().hasValue());
 }
+
+TEST(AnalysisManagerTest, OpSpecificAnalysis) {
+  MLIRContext context;
+
+  // Create an empty module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  // Query the op-specific analysis for the module and verify that it's cached.
+  am.getAnalysis<OpSpecificAnalysis, ModuleOp>();
+  EXPECT_TRUE(am.getCachedAnalysis<OpSpecificAnalysis>().hasValue());
+}
+
 } // end namespace

diff  --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index a5aaee378f33..52cee34cee65 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRPassTests
   AnalysisManagerTest.cpp
+  PassManagerTest.cpp
 )
 target_link_libraries(MLIRPassTests
   PRIVATE

diff  --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
new file mode 100644
index 000000000000..29086a2994e8
--- /dev/null
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -0,0 +1,77 @@
+//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/PassManager.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+/// Analysis that operates on any operation.
+struct GenericAnalysis {
+  GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
+  const bool isFunc;
+};
+
+/// Analysis that can only be constructed from a FuncOp.
+struct OpSpecificAnalysis {
+  OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
+  const bool isSecret;
+};
+
+/// Simple pass to annotate a FuncOp with the results of both analyses.
+/// Note: not using FunctionPass, as it skips external functions.
+struct AnnotateFunctionPass
+    : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
+  void runOnOperation() override {
+    FuncOp op = getOperation();
+    Builder builder(op.getParentOfType<ModuleOp>());
+
+    auto &ga = getAnalysis<GenericAnalysis>();
+    auto &sa = getAnalysis<OpSpecificAnalysis>();
+
+    op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
+    op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
+  }
+};
+
+TEST(PassManagerTest, OpSpecificAnalysis) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a module with 2 body-less (external) functions.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  for (StringRef name : {"secret", "not_secret"}) {
+    FuncOp func =
+        FuncOp::create(builder.getUnknownLoc(), name,
+                       builder.getFunctionType(llvm::None, llvm::None));
+    module->push_back(func);
+  }
+
+  // Instantiate and run our pass on each nested FuncOp.
+  PassManager pm(&context);
+  pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
+  LogicalResult result = pm.run(module.get());
+  EXPECT_TRUE(succeeded(result));
+
+  // Verify that each function got annotated with the expected attributes.
+  for (FuncOp func : module->getOps<FuncOp>()) {
+    ASSERT_TRUE(func.getAttr("isFunc").isa<BoolAttr>());
+    EXPECT_TRUE(func.getAttr("isFunc").cast<BoolAttr>().getValue());
+
+    bool isSecret = func.getName() == "secret";
+    ASSERT_TRUE(func.getAttr("isSecret").isa<BoolAttr>());
+    EXPECT_EQ(func.getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
+  }
+}
+
+} // end namespace


        


More information about the Mlir-commits mailing list