[Mlir-commits] [mlir] b1aaed0 - Enable `Pass::initialize()` to fail by returning a LogicalResult
Mehdi Amini
llvmlistbot at llvm.org
Wed Feb 10 17:52:05 PST 2021
Author: Mehdi Amini
Date: 2021-02-11T01:51:53Z
New Revision: b1aaed023e98ee9989532c8d3914c3bec7bbf964
URL: https://github.com/llvm/llvm-project/commit/b1aaed023e98ee9989532c8d3914c3bec7bbf964
DIFF: https://github.com/llvm/llvm-project/commit/b1aaed023e98ee9989532c8d3914c3bec7bbf964.diff
LOG: Enable `Pass::initialize()` to fail by returning a LogicalResult
Differential Revision: https://reviews.llvm.org/D96474
Added:
Modified:
mlir/docs/PassManagement.md
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Transforms/Canonicalizer.cpp
Removed:
################################################################################
diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 8fc7ff2f0ae1..7588f4fba1b5 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -140,13 +140,15 @@ such example is when using [`PDL`-based](Dialects/PDLOps.md)
runtime. In these situations, a pass may override the following hook to
initialize this heavy state:
-* `void initialize(MLIRContext *context)`
+* `LogicalResult initialize(MLIRContext *context)`
This hook is executed once per run of a full pass pipeline, meaning that it does
not have access to the state available during a `runOnOperation` call. More
concretely, all necessary accesses to an `MLIRContext` should be driven via the
provided `context` parameter, and methods that utilize "per-run" state such as
`getContext`/`getOperation`/`getAnalysis`/etc. must not be used.
+In case of an error during initialization, the pass is expected to emit an error
+diagnostic and return `failure()`, which aborts execution of the pass pipeline.
## Analysis Management
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 3eeaf79b9610..10b1a5c57a81 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -170,7 +170,9 @@ class Pass {
/// should not rely on any state accessible during the execution of a pass.
/// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be
/// invoked within this hook.
- virtual void initialize(MLIRContext *context) {}
+ /// Returns a LogicalResult to indicate failure, in which case the pass
+ /// pipeline won't execute.
+ virtual LogicalResult initialize(MLIRContext *context) { return success(); }
/// Schedule an arbitrary pass pipeline on the provided operation.
/// This can be invoke any time in a pass to dynamic schedule more passes.
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 99798bdd551d..b6e9e2766773 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -130,7 +130,7 @@ class OpPassManager {
/// Initialize all of the passes within this pass manager with the given
/// initialization generation. The initialization generation is used to detect
/// if a pass manager has already been initialized.
- void initialize(MLIRContext *context, unsigned newInitGeneration);
+ LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
/// A pointer to an internal implementation instance.
std::unique_ptr<detail::OpPassManagerImpl> impl;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f4779cb860a0..8507fd6d3451 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -331,23 +331,26 @@ void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
-void OpPassManager::initialize(MLIRContext *context,
- unsigned newInitGeneration) {
+LogicalResult OpPassManager::initialize(MLIRContext *context,
+ unsigned newInitGeneration) {
if (impl->initializationGeneration == newInitGeneration)
- return;
+ return success();
impl->initializationGeneration = newInitGeneration;
for (Pass &pass : getPasses()) {
// If this pass isn't an adaptor, directly initialize it.
auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
if (!adaptor) {
- pass.initialize(context);
+ if (failed(pass.initialize(context)))
+ return failure();
continue;
}
// Otherwise, initialize each of the adaptors pass managers.
for (OpPassManager &adaptorPM : adaptor->getPassManagers())
- adaptorPM.initialize(context, newInitGeneration);
+ if (failed(adaptorPM.initialize(context, newInitGeneration)))
+ return failure();
}
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -379,7 +382,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
assert(pipeline.getOpName() == root->getName().getStringRef());
// Initialize the user provided pipeline and execute the pipeline.
- pipeline.initialize(root->getContext(), parentInitGeneration);
+ if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
+ return failure();
AnalysisManager nestedAm = root == op ? am : am.nest(root);
return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
verifyPasses, parentInitGeneration,
@@ -872,7 +876,8 @@ LogicalResult PassManager::run(Operation *op) {
// Initialize all of the passes within the pass manager with a new generation.
llvm::hash_code newInitKey = context->getRegistryHash();
if (newInitKey != initializationKey) {
- initialize(context, impl->initializationGeneration + 1);
+ if (failed(initialize(context, impl->initializationGeneration + 1)))
+ return failure();
initializationKey = newInitKey;
}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index cf0b21735362..ba6ffe72a248 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,11 +23,12 @@ namespace {
struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
- void initialize(MLIRContext *context) override {
+ LogicalResult initialize(MLIRContext *context) override {
OwningRewritePatternList owningPatterns;
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
+ return success();
}
void runOnOperation() override {
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
More information about the Mlir-commits
mailing list