[Mlir-commits] [mlir] c0b6bc0 - Decouple OpPassManager from the MLIRContext (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Wed Sep 2 23:02:16 PDT 2020
Author: Mehdi Amini
Date: 2020-09-03T06:02:05Z
New Revision: c0b6bc070e78cbd20bc4351704f52d85192e8804
URL: https://github.com/llvm/llvm-project/commit/c0b6bc070e78cbd20bc4351704f52d85192e8804
DIFF: https://github.com/llvm/llvm-project/commit/c0b6bc070e78cbd20bc4351704f52d85192e8804.diff
LOG: Decouple OpPassManager from the MLIRContext (NFC)
This allows building an OpPassManager from a StringRef instead of an
Identifier, which enables building pipelines without an MLIRContext.
An Identifier is still created on demand and cached on the OpPassManager for
efficiency during the IR traversal.
Added:
Modified:
mlir/include/mlir/Pass/PassManager.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassStatistics.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 8addd9809f90..ec88485cd3ef 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -47,7 +47,8 @@ struct OpPassManagerImpl;
/// other OpPassManagers or the top-level PassManager.
class OpPassManager {
public:
- OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
+ OpPassManager(Identifier name, bool verifyPasses);
+ OpPassManager(StringRef name, bool verifyPasses);
OpPassManager(OpPassManager &&rhs);
OpPassManager(const OpPassManager &rhs);
~OpPassManager();
@@ -73,7 +74,7 @@ class OpPassManager {
OpPassManager &nest(Identifier nestedName);
OpPassManager &nest(StringRef nestedName);
template <typename OpT> OpPassManager &nest() {
- return nest(Identifier::get(OpT::getOperationName(), getContext()));
+ return nest(OpT::getOperationName());
}
/// Add the given pass to this pass manager. If this pass has a concrete
@@ -89,11 +90,11 @@ class OpPassManager {
/// Returns the number of passes held by this manager.
size_t size() const;
- /// Return an instance of the context.
- MLIRContext *getContext() const;
+ /// Return the operation name that this pass manager operates on.
+ Identifier getOpName(MLIRContext &context) const;
/// Return the operation name that this pass manager operates on.
- Identifier getOpName() const;
+ StringRef getOpName() const;
/// Returns the internal implementation instance.
detail::OpPassManagerImpl &getImpl();
@@ -151,6 +152,9 @@ class PassManager : public OpPassManager {
LLVM_NODISCARD
LogicalResult run(ModuleOp module);
+ /// Return an instance of the context.
+ MLIRContext *getContext() const { return context; }
+
/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `outputFile` is a .mlir filename used to
/// write the generated reproducer. If `genLocalReproducer` is true, the pass
@@ -304,6 +308,8 @@ class PassManager : public OpPassManager {
runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
ModuleOp module, AnalysisManager am);
+ MLIRContext *context;
+
/// Flag that specifies if pass statistics should be dumped.
Optional<PassDisplayMode> passStatisticsMode;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d3cf62574afd..3ac41cde7911 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -92,8 +92,10 @@ void VerifierPass::runOnOperation() {
namespace mlir {
namespace detail {
struct OpPassManagerImpl {
- OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
- : name(name), context(ctx), verifyPasses(verifyPasses) {}
+ OpPassManagerImpl(Identifier identifier, bool verifyPasses)
+ : name(identifier), identifier(identifier), verifyPasses(verifyPasses) {}
+ OpPassManagerImpl(StringRef name, bool verifyPasses)
+ : name(name), verifyPasses(verifyPasses) {}
/// Merge the passes of this pass manager into the one provided.
void mergeInto(OpPassManagerImpl &rhs);
@@ -101,9 +103,7 @@ struct OpPassManagerImpl {
/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(Identifier nestedName);
- OpPassManager &nest(StringRef nestedName) {
- return nest(Identifier::get(nestedName, getContext()));
- }
+ OpPassManager &nest(StringRef nestedName);
/// Add the given pass to this pass manager. If this pass has a concrete
/// operation type, it must be the same type as this pass manager.
@@ -117,14 +117,18 @@ struct OpPassManagerImpl {
/// pass.
void splitAdaptorPasses();
- /// Return an instance of the context.
- MLIRContext *getContext() const { return context; }
+ Identifier getOpName(MLIRContext &context) {
+ if (!identifier)
+ identifier = Identifier::get(name, &context);
+ return *identifier;
+ }
/// The name of the operation that passes of this pass manager operate on.
- Identifier name;
+ StringRef name;
- /// The current context for this pass manager
- MLIRContext *context;
+ /// The cached identifier (internalized in the context) for the name of the
+ /// operation that passes of this pass manager operate on.
+ Optional<Identifier> identifier;
/// Flag that specifies if the IR should be verified after each pass has run.
bool verifyPasses : 1;
@@ -143,7 +147,14 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
}
OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
- OpPassManager nested(nestedName, getContext(), verifyPasses);
+ OpPassManager nested(nestedName, verifyPasses);
+ auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
+ addPass(std::unique_ptr<Pass>(adaptor));
+ return adaptor->getPassManagers().front();
+}
+
+OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
+ OpPassManager nested(nestedName, verifyPasses);
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
@@ -153,7 +164,7 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
// If this pass runs on a different operation than this pass manager, then
// implicitly nest a pass manager for this operation.
auto passOpName = pass->getOpName();
- if (passOpName && passOpName != name.strref())
+ if (passOpName && passOpName != name)
return nest(*passOpName).addPass(std::move(pass));
passes.emplace_back(std::move(pass));
@@ -240,14 +251,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
// OpPassManager
//===----------------------------------------------------------------------===//
-OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
- bool verifyPasses)
- : impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
+OpPassManager::OpPassManager(Identifier name, bool verifyPasses)
+ : impl(new OpPassManagerImpl(name, verifyPasses)) {}
+OpPassManager::OpPassManager(StringRef name, bool verifyPasses)
+ : impl(new OpPassManagerImpl(name, verifyPasses)) {}
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
- impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
- rhs.impl->verifyPasses));
+ impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
for (auto &pass : rhs.impl->passes)
impl->passes.emplace_back(pass->clone());
return *this;
@@ -290,11 +301,13 @@ size_t OpPassManager::size() const { return impl->passes.size(); }
/// Returns the internal implementation instance.
OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
-/// Return an instance of the context.
-MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
+/// Return the operation name that this pass manager operates on.
+StringRef OpPassManager::getOpName() const { return impl->name; }
/// Return the operation name that this pass manager operates on.
-Identifier OpPassManager::getOpName() const { return impl->name; }
+Identifier OpPassManager::getOpName(MLIRContext &context) const {
+ return impl->getOpName(context);
+}
/// Prints out the given passes as the textual representation of a pipeline.
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
@@ -389,12 +402,22 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
/// Find an operation pass manager that can operate on an operation of the given
/// type, or nullptr if one does not exist.
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
- Identifier name) {
+ StringRef name) {
auto it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
return it == mgrs.end() ? nullptr : &*it;
}
+/// Find an operation pass manager that can operate on an operation of the given
+/// type, or nullptr if one does not exist.
+static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
+ Identifier name,
+ MLIRContext &context) {
+ auto it = llvm::find_if(
+ mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
+ return it == mgrs.end() ? nullptr : &*it;
+}
+
OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
@@ -421,8 +444,7 @@ void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
// After coalescing, sort the pass managers within rhs by name.
llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
[](const OpPassManager *lhs, const OpPassManager *rhs) {
- return lhs->getOpName().strref().compare(
- rhs->getOpName().strref());
+ return lhs->getOpName().compare(rhs->getOpName());
});
}
@@ -454,16 +476,18 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
for (auto &op : block) {
- auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
+ auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
+ *op.getContext());
if (!mgr)
continue;
+ Identifier opName = mgr->getOpName(*getOperation()->getContext());
// Run the held pipeline over the current operation.
if (instrumentor)
- instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo);
+ instrumentor->runBeforePipeline(opName, parentInfo);
auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op));
if (instrumentor)
- instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo);
+ instrumentor->runAfterPipeline(opName, parentInfo);
if (failed(result))
return signalPassFailure();
@@ -499,7 +523,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
for (auto &block : region) {
for (auto &op : block) {
// Add this operation iff the name matches any of the pass managers.
- if (findPassManagerFor(mgrs, op.getName().getIdentifier()))
+ if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
+ getContext()))
opAMPairs.emplace_back(&op, am.nest(&op));
}
}
@@ -535,16 +560,17 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
// Get the pass manager for this operation and execute it.
auto &it = opAMPairs[nextID];
- auto *pm =
- findPassManagerFor(pms, it.first->getName().getIdentifier());
+ auto *pm = findPassManagerFor(
+ pms, it.first->getName().getIdentifier(), getContext());
assert(pm && "expected valid pass manager for operation");
+ Identifier opName = pm->getOpName(*getOperation()->getContext());
if (instrumentor)
- instrumentor->runBeforePipeline(pm->getOpName(), parentInfo);
+ instrumentor->runBeforePipeline(opName, parentInfo);
auto pipelineResult =
runPipeline(pm->getPasses(), it.first, it.second);
if (instrumentor)
- instrumentor->runAfterPipeline(pm->getOpName(), parentInfo);
+ instrumentor->runAfterPipeline(opName, parentInfo);
// Drop this thread from being tracked by the diagnostic handler.
// After this task has finished, the thread may be used outside of
@@ -737,9 +763,9 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
//===----------------------------------------------------------------------===//
PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
- : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
+ : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
verifyPasses),
- passTiming(false), localReproducer(false) {}
+ context(ctx), passTiming(false), localReproducer(false) {}
PassManager::~PassManager() {}
diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
index 3721230b6913..d909c98abf56 100644
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -116,7 +116,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
// Print each of the children passes.
for (OpPassManager &mgr : mgrs) {
- auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
+ auto name = ("'" + mgr.getOpName() + "' Pipeline").str();
printPassEntry(os, indent, name);
for (Pass &pass : mgr.getPasses())
printPass(indent + 2, &pass);
More information about the Mlir-commits
mailing list