[Mlir-commits] [mlir] 9a4b30c - [MLIR] Add support for defining and using Op specific analysis
Rahul Joshi
llvmlistbot at llvm.org
Mon Aug 17 09:01:16 PDT 2020
Author: Rahul Joshi
Date: 2020-08-17T09:00:47-07:00
New Revision: 9a4b30cf84298887f0a7bf70b865493d767abdc9
URL: https://github.com/llvm/llvm-project/commit/9a4b30cf84298887f0a7bf70b865493d767abdc9
DIFF: https://github.com/llvm/llvm-project/commit/9a4b30cf84298887f0a7bf70b865493d767abdc9.diff
LOG: [MLIR] Add support for defining and using Op specific analysis
- Add variants of getAnalysis() and friends that operate on a specific derived
operation type.
- Add OperationPass::getAnalysis() to always call the base getAnalysis() with OpT.
- With this, an OperationPass can call getAnalysis<> using an analysis type that
is generic (works on Operation *) or specific to the OpT for the pass. Anything
else will fail to compile.
- Extend AnalysisManager unit test to test this, and add a new PassManager unit
test to test this functionality in the context of an OperationPass.
Differential Revision: https://reviews.llvm.org/D84897
Added:
mlir/unittests/Pass/PassManagerTest.cpp
Modified:
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/unittests/Pass/AnalysisManagerTest.cpp
mlir/unittests/Pass/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 4e9b3f20c929..37036e2298d0 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -128,25 +128,18 @@ class AnalysisMap {
explicit AnalysisMap(Operation *ir) : ir(ir) {}
/// Get an analysis for the current IR unit, computing it if necessary.
- template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
- TypeID id = TypeID::get<AnalysisT>();
-
- typename ConceptMap::iterator it;
- bool wasInserted;
- std::tie(it, wasInserted) = analyses.try_emplace(id);
-
- // If we don't have a cached analysis for this function, compute it directly
- // and add it to the cache.
- if (wasInserted) {
- if (pi)
- pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
-
- it->second = std::make_unique<AnalysisModel<AnalysisT>>(ir);
+ template <typename AnalysisT>
+ AnalysisT &getAnalysis(PassInstrumentor *pi) {
+ return getAnalysisImpl<AnalysisT, Operation *>(pi, ir);
+ }
- if (pi)
- pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
- }
- return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
+ /// Get an analysis for the current IR unit assuming it's of specific derived
+ /// operation type.
+ template <typename AnalysisT, typename OpT>
+ typename std::enable_if<std::is_constructible<AnalysisT, OpT>::value,
+ AnalysisT &>::type
+ getAnalysis(PassInstrumentor *pi) {
+ return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir));
}
/// Get a cached analysis instance if one exists, otherwise return null.
@@ -176,6 +169,28 @@ class AnalysisMap {
}
private:
+ template <typename AnalysisT, typename OpT>
+ AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) {
+ TypeID id = TypeID::get<AnalysisT>();
+
+ typename ConceptMap::iterator it;
+ bool wasInserted;
+ std::tie(it, wasInserted) = analyses.try_emplace(id);
+
+ // If we don't have a cached analysis for this function, compute it directly
+ // and add it to the cache.
+ if (wasInserted) {
+ if (pi)
+ pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+
+ it->second = std::make_unique<AnalysisModel<AnalysisT>>(op);
+
+ if (pi)
+ pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
+ }
+ return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
+ }
+
Operation *ir;
ConceptMap analyses;
};
@@ -216,8 +231,8 @@ class AnalysisManager {
public:
using PreservedAnalyses = detail::PreservedAnalyses;
- // Query for a cached analysis on the given parent operation. The analysis may
- // not exist and if it does it may be out-of-date.
+ /// Query for a cached analysis on the given parent operation. The analysis
+ /// may not exist and if it does it may be out-of-date.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>>
getCachedParentAnalysis(Operation *parentOp) const {
@@ -230,12 +245,19 @@ class AnalysisManager {
return None;
}
- // Query for the given analysis for the current operation.
+ /// Query for the given analysis for the current operation.
template <typename AnalysisT> AnalysisT &getAnalysis() {
return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
}
- // Query for a cached entry of the given analysis on the current operation.
+ /// Query for the given analysis for the current operation of a specific
+ /// derived operation type.
+ template <typename AnalysisT, typename OpT>
+ AnalysisT &getAnalysis() {
+ return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor());
+ }
+
+ /// Query for a cached entry of the given analysis on the current operation.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
return impl->analyses.getCachedAnalysis<AnalysisT>();
@@ -246,6 +268,13 @@ class AnalysisManager {
return slice(op).template getAnalysis<AnalysisT>();
}
+ /// Query for an analysis of a child operation of a specific derived operation
+ /// type, constructing it if necessary.
+ template <typename AnalysisT, typename OpT>
+ AnalysisT &getChildAnalysis(OpT child) {
+ return slice(child).template getAnalysis<AnalysisT, OpT>();
+ }
+
/// Query for a cached analysis of a child operation, or return null.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>>
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7c0f9bd958a1..8de31d944319 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -167,6 +167,13 @@ class Pass {
return getAnalysisManager().getAnalysis<AnalysisT>();
}
+ /// Query an analysis for the current ir unit of a specific derived operation
+ /// type.
+ template <typename AnalysisT, typename OpT>
+ AnalysisT &getAnalysis() {
+ return getAnalysisManager().getAnalysis<AnalysisT, OpT>();
+ }
+
/// Query a cached instance of an analysis for the current ir unit if one
/// exists.
template <typename AnalysisT>
@@ -187,12 +194,14 @@ class Pass {
getPassState().preservedAnalyses.preserve(id);
}
- /// Returns the analysis for the parent operation if it exists.
+ /// Returns the analysis for the given parent operation if it exists.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>>
getCachedParentAnalysis(Operation *parent) {
return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(parent);
}
+
+ /// Returns the analysis for the parent operation if it exists.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(
@@ -212,6 +221,13 @@ class Pass {
return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
}
+ /// Returns the analysis for the given child operation of specific derived
+ /// operation type, or creates it if it doesn't exist.
+ template <typename AnalysisT, typename OpTy>
+ AnalysisT &getChildAnalysis(OpTy child) {
+ return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
+ }
+
/// Returns the current analysis manager.
AnalysisManager getAnalysisManager() {
return getPassState().analysisManager;
@@ -286,6 +302,13 @@ template <typename OpT = void> class OperationPass : public Pass {
/// Return the current operation being transformed.
OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
+
+ /// Query an analysis for the current operation of the specific derived
+ /// operation type.
+ template <typename AnalysisT>
+ AnalysisT &getAnalysis() {
+ return Pass::getAnalysis<AnalysisT, OpT>();
+ }
};
/// Pass to transform an operation.
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index a99df3911a5d..41a90649deef 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -9,6 +9,8 @@
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "gtest/gtest.h"
using namespace mlir;
@@ -22,6 +24,9 @@ struct MyAnalysis {
struct OtherAnalysis {
OtherAnalysis(Operation *) {}
};
+struct OpSpecificAnalysis {
+ OpSpecificAnalysis(ModuleOp) {}
+};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
MLIRContext context;
@@ -138,4 +143,18 @@ TEST(AnalysisManagerTest, CustomInvalidation) {
am.invalidate(pa);
EXPECT_TRUE(am.getCachedAnalysis<CustomInvalidatingAnalysis>().hasValue());
}
+
+TEST(AnalysisManagerTest, OpSpecificAnalysis) {
+ MLIRContext context;
+
+ // Create a module.
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+ AnalysisManager am = mam;
+
+ // Query the op specific analysis for the module and verify that it's cached.
+ am.getAnalysis<OpSpecificAnalysis, ModuleOp>();
+ EXPECT_TRUE(am.getCachedAnalysis<OpSpecificAnalysis>().hasValue());
+}
+
} // end namespace
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index a5aaee378f33..52cee34cee65 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRPassTests
AnalysisManagerTest.cpp
+ PassManagerTest.cpp
)
target_link_libraries(MLIRPassTests
PRIVATE
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
new file mode 100644
index 000000000000..29086a2994e8
--- /dev/null
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -0,0 +1,77 @@
+//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/PassManager.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+/// Analysis that operates on any operation.
+struct GenericAnalysis {
+ GenericAnalysis(Operation *op) : isFunc(isa<FuncOp>(op)) {}
+ const bool isFunc;
+};
+
+/// Analysis that operates on a specific operation.
+struct OpSpecificAnalysis {
+ OpSpecificAnalysis(FuncOp op) : isSecret(op.getName() == "secret") {}
+ const bool isSecret;
+};
+
+/// Simple pass to annotate a FuncOp with the results of analysis.
+/// Note: not using FunctionPass as it skips external functions.
+struct AnnotateFunctionPass
+ : public PassWrapper<AnnotateFunctionPass, OperationPass<FuncOp>> {
+ void runOnOperation() override {
+ FuncOp op = getOperation();
+ Builder builder(op.getParentOfType<ModuleOp>());
+
+ auto &ga = getAnalysis<GenericAnalysis>();
+ auto &sa = getAnalysis<OpSpecificAnalysis>();
+
+ op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
+ op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
+ }
+};
+
+TEST(PassManagerTest, OpSpecificAnalysis) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ // Create a module with 2 functions.
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+ for (StringRef name : {"secret", "not_secret"}) {
+ FuncOp func =
+ FuncOp::create(builder.getUnknownLoc(), name,
+ builder.getFunctionType(llvm::None, llvm::None));
+ module->push_back(func);
+ }
+
+ // Instantiate and run our pass.
+ PassManager pm(&context);
+ pm.addNestedPass<FuncOp>(std::make_unique<AnnotateFunctionPass>());
+ LogicalResult result = pm.run(module.get());
+ EXPECT_TRUE(succeeded(result));
+
+ // Verify that each function got annotated with expected attributes.
+ for (FuncOp func : module->getOps<FuncOp>()) {
+ ASSERT_TRUE(func.getAttr("isFunc").isa<BoolAttr>());
+ EXPECT_TRUE(func.getAttr("isFunc").cast<BoolAttr>().getValue());
+
+ bool isSecret = func.getName() == "secret";
+ ASSERT_TRUE(func.getAttr("isSecret").isa<BoolAttr>());
+ EXPECT_EQ(func.getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
+ }
+}
+
+} // end namespace
More information about the Mlir-commits
mailing list