[Mlir-commits] [mlir] 97c899f - [mlir] Add callback to provide a pass pipeline to MlirOptMain
Mehdi Amini
llvmlistbot at llvm.org
Fri Nov 5 10:46:42 PDT 2021
Author: Deepak Panickal
Date: 2021-11-05T17:46:35Z
New Revision: 97c899f3c5d9bbff2824b3252b21378bf96f3f3f
URL: https://github.com/llvm/llvm-project/commit/97c899f3c5d9bbff2824b3252b21378bf96f3f3f
DIFF: https://github.com/llvm/llvm-project/commit/97c899f3c5d9bbff2824b3252b21378bf96f3f3f.diff
LOG: [mlir] Add callback to provide a pass pipeline to MlirOptMain
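The callback receives the PassManager and can be used to set up a default pass pipeline; the existing PassPipelineCLParser-based overload is now implemented on top of it.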
Reviewed By: mehdi_amini, rriddle
Differential Revision: https://reviews.llvm.org/D113144
Added:
Modified:
mlir/include/mlir/Support/MlirOptMain.h
mlir/lib/Support/MlirOptMain.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
index 4ae3535d4c408..51a26d08341a4 100644
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -27,6 +27,12 @@ class MemoryBuffer;
namespace mlir {
class DialectRegistry;
class PassPipelineCLParser;
+class PassManager;
+
+/// Signature of the callback used by `MlirOptMain` to configure the
+/// PassManager, e.g. to populate a default pass pipeline that will then be run
+/// on the loaded IR. The callback signals success or failure through its
+/// LogicalResult return value.
+/// Note: `llvm::function_ref` does not own the callable it refers to, so the
+/// callable must remain alive for as long as the reference is used.
+using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &)>;
/// Perform the core processing behind `mlir-opt`:
/// - outputStream is the stream where the resulting IR is printed.
@@ -52,6 +58,17 @@ LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
bool allowUnregisteredDialects,
bool preloadDialectsInContext = false);
+/// Variant of `MlirOptMain` that configures the pass manager through a
+/// callback instead of a `PassPipelineCLParser`.
+/// - passManagerSetupFn is called with a PassManager that already has the
+///   pass-manager command-line options applied and timing enabled; it is
+///   expected to populate the pipeline to run on the loaded IR. If it returns
+///   failure, the pipeline is not run for that input. It is called once per
+///   processed buffer, i.e. once per chunk when the input is split.
+/// - The remaining parameters have the same meaning as in the
+///   `PassPipelineCLParser` overload.
+LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
+                          std::unique_ptr<llvm::MemoryBuffer> buffer,
+                          PassPipelineFn passManagerSetupFn,
+                          DialectRegistry &registry, bool splitInputFile,
+                          bool verifyDiagnostics, bool verifyPasses,
+                          bool allowUnregisteredDialects,
+                          bool preloadDialectsInContext = false);
+
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
/// - registry should contain all the dialects that can be parsed in the source.
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 6da9f993eeafb..9a8b21d37f254 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -48,7 +48,7 @@ using llvm::SMLoc;
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
- const PassPipelineCLParser &passPipeline) {
+ PassPipelineFn passManagerSetupFn) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -72,13 +72,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);
- auto errorHandler = [&](const Twine &msg) {
- emitError(UnknownLoc::get(context)) << msg;
- return failure();
- };
-
- // Build the provided pipeline.
- if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+ // Callback to build the pipeline.
+ if (failed(passManagerSetupFn(pm)))
return failure();
// Run the pipeline.
@@ -98,8 +93,8 @@ static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
- const PassPipelineCLParser &passPipeline,
- DialectRegistry ®istry, llvm::ThreadPool &threadPool) {
+ PassPipelineFn passManagerSetupFn, DialectRegistry ®istry,
+ llvm::ThreadPool &threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +117,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passPipeline);
+ &context, passManagerSetupFn);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +126,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passPipeline);
+ passManagerSetupFn);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@@ -140,7 +135,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
- const PassPipelineCLParser &passPipeline,
+ PassPipelineFn passManagerSetupFn,
DialectRegistry ®istry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
@@ -156,17 +151,36 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passPipeline, registry,
- threadPool);
+ preloadDialectsInContext, passManagerSetupFn,
+ registry, threadPool);
},
outputStream);
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passPipeline, registry,
+ preloadDialectsInContext, passManagerSetupFn, registry,
threadPool);
}
+/// Overload taking a `PassPipelineCLParser`: it adapts the command-line
+/// pipeline into a pass-manager setup callback and forwards to the
+/// callback-based overload.
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                const PassPipelineCLParser &passPipeline,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext) {
+  // Populate the pass manager from the command-line pipeline. Errors that
+  // addToPipeline reports through the handler are emitted as diagnostics at
+  // an unknown location in the pass manager's context.
+  auto passManagerSetupFn = [&](PassManager &pm) {
+    auto errorHandler = [&](const Twine &msg) {
+      emitError(UnknownLoc::get(pm.getContext())) << msg;
+      return failure();
+    };
+    return passPipeline.addToPipeline(pm, errorHandler);
+  };
+  // PassPipelineFn is a non-owning function_ref; binding it to the local
+  // lambda is safe because the lambda outlives the call below.
+  return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
+                     registry, splitInputFile, verifyDiagnostics, verifyPasses,
+                     allowUnregisteredDialects, preloadDialectsInContext);
+}
+
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry ®istry,
bool preloadDialectsInContext) {
More information about the Mlir-commits
mailing list