[Mlir-commits] [mlir] 97c899f - [mlir] Add callback to provide a pass pipeline to MlirOptMain

Mehdi Amini llvmlistbot at llvm.org
Fri Nov 5 10:46:42 PDT 2021


Author: Deepak Panickal
Date: 2021-11-05T17:46:35Z
New Revision: 97c899f3c5d9bbff2824b3252b21378bf96f3f3f

URL: https://github.com/llvm/llvm-project/commit/97c899f3c5d9bbff2824b3252b21378bf96f3f3f
DIFF: https://github.com/llvm/llvm-project/commit/97c899f3c5d9bbff2824b3252b21378bf96f3f3f.diff

LOG: [mlir] Add callback to provide a pass pipeline to MlirOptMain

The callback can be used to provide a default pass pipeline.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D113144

Added: 
    

Modified: 
    mlir/include/mlir/Support/MlirOptMain.h
    mlir/lib/Support/MlirOptMain.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
index 4ae3535d4c408..51a26d08341a4 100644
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -27,6 +27,12 @@ class MemoryBuffer;
 namespace mlir {
 class DialectRegistry;
 class PassPipelineCLParser;
+class PassManager;
+
+/// Callback used to set up the pass manager, e.g. to add a default pass
+/// pipeline to apply on the loaded IR. If it returns failure, the pipeline is
+/// not run. Note: llvm::function_ref does not own the referenced callable.
+using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
 
 /// Perform the core processing behind `mlir-opt`:
 /// - outputStream is the stream where the resulting IR is printed.
@@ -52,6 +58,17 @@ LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
                           bool allowUnregisteredDialects,
                           bool preloadDialectsInContext = false);
 
+/// Variant of the PassPipelineCLParser overload that takes a callback.
+/// - passManagerSetupFn is invoked to set up the pass manager applied to the
+///   loaded IR; if it returns failure, the pipeline is not run.
+LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
+                          std::unique_ptr<llvm::MemoryBuffer> buffer,
+                          PassPipelineFn passManagerSetupFn,
+                          DialectRegistry &registry, bool splitInputFile,
+                          bool verifyDiagnostics, bool verifyPasses,
+                          bool allowUnregisteredDialects,
+                          bool preloadDialectsInContext = false);
+
 /// Implementation for tools like `mlir-opt`.
 /// - toolName is used for the header displayed by `--help`.
 /// - registry should contain all the dialects that can be parsed in the source.

diff  --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 6da9f993eeafb..9a8b21d37f254 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -48,7 +48,7 @@ using llvm::SMLoc;
 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
                                     bool verifyPasses, SourceMgr &sourceMgr,
                                     MLIRContext *context,
-                                    const PassPipelineCLParser &passPipeline) {
+                                    PassPipelineFn passManagerSetupFn) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -72,13 +72,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
   applyPassManagerCLOptions(pm);
   pm.enableTiming(timing);
 
-  auto errorHandler = [&](const Twine &msg) {
-    emitError(UnknownLoc::get(context)) << msg;
-    return failure();
-  };
-
-  // Build the provided pipeline.
-  if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+  // Callback to build the pipeline.
+  if (failed(passManagerSetupFn(pm)))
     return failure();
 
   // Run the pipeline.
@@ -98,8 +93,8 @@ static LogicalResult
 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
               bool verifyDiagnostics, bool verifyPasses,
               bool allowUnregisteredDialects, bool preloadDialectsInContext,
-              const PassPipelineCLParser &passPipeline,
-              DialectRegistry &registry, llvm::ThreadPool &threadPool) {
+              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+              llvm::ThreadPool &threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +117,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
   if (!verifyDiagnostics) {
     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
     return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passPipeline);
+                          &context, passManagerSetupFn);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +126,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
   (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                       passPipeline);
+                       passManagerSetupFn);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
@@ -140,7 +135,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
 
 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                 std::unique_ptr<MemoryBuffer> buffer,
-                                const PassPipelineCLParser &passPipeline,
+                                PassPipelineFn passManagerSetupFn,
                                 DialectRegistry &registry, bool splitInputFile,
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
@@ -156,17 +151,36 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                                verifyPasses, allowUnregisteredDialects,
-                               preloadDialectsInContext, passPipeline, registry,
-                               threadPool);
+                               preloadDialectsInContext, passManagerSetupFn,
+                               registry, threadPool);
         },
         outputStream);
 
   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
                        verifyPasses, allowUnregisteredDialects,
-                       preloadDialectsInContext, passPipeline, registry,
+                       preloadDialectsInContext, passManagerSetupFn, registry,
                        threadPool);
 }
 
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                const PassPipelineCLParser &passPipeline,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext) {
+  auto passManagerSetupFn = [&](PassManager &pm) { // Adapt to PassPipelineFn.
+    auto errorHandler = [&](const Twine &msg) { // Emit as an MLIR diagnostic.
+      emitError(UnknownLoc::get(pm.getContext())) << msg;
+      return failure();
+    };
+    return passPipeline.addToPipeline(pm, errorHandler);
+  };
+  return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
+                     registry, splitInputFile, verifyDiagnostics, verifyPasses,
+                     allowUnregisteredDialects, preloadDialectsInContext);
+}
+
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
                                 DialectRegistry &registry,
                                 bool preloadDialectsInContext) {


        


More information about the Mlir-commits mailing list