[Mlir-commits] [mlir] 8d2e0c1 - Refactor a MlirOptMainConfig class to hold the configuration of MlirOptMain (NFC)

Mehdi Amini llvmlistbot at llvm.org
Sat Feb 25 14:04:37 PST 2023


Author: Mehdi Amini
Date: 2023-02-25T17:04:29-05:00
New Revision: 8d2e0c16e691797f11a713a5601dc01e50d8734a

URL: https://github.com/llvm/llvm-project/commit/8d2e0c16e691797f11a713a5601dc01e50d8734a
DIFF: https://github.com/llvm/llvm-project/commit/8d2e0c16e691797f11a713a5601dc01e50d8734a.diff

LOG: Refactor a MlirOptMainConfig class to hold the configuration of MlirOptMain (NFC)

The list of boolean flags and other parameters is becoming unreasonably long.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D143829

Added: 
    

Modified: 
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 34f506273eff3..70cb2957e7182 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/StringRef.h"
 
 #include <cstdlib>
+#include <functional>
 #include <memory>
 
 namespace llvm {
@@ -29,31 +30,153 @@ class DialectRegistry;
 class PassPipelineCLParser;
 class PassManager;
 
+/// Configuration options for the mlir-opt tool.
+/// This is intended to help build tools like mlir-opt by collecting the
+/// supported options.
+/// The API is fluent, and the options are sorted in alphabetical order below.
+class MlirOptMainConfig {
+public:
+  /// Allow operations from dialects not registered in the MLIRContext.
+  /// This option is for convenience during testing only and discouraged in
+  /// general.
+  MlirOptMainConfig &allowUnregisteredDialects(bool allow) {
+    allowUnregisteredDialectsFlag = allow;
+    return *this;
+  }
+  bool shouldAllowUnregisteredDialects() const {
+    return allowUnregisteredDialectsFlag;
+  }
+
+  /// Print the pass pipeline before running it (only with a pipeline parser).
+  MlirOptMainConfig &dumpPassPipeline(bool dump) {
+    dumpPassPipelineFlag = dump;
+    return *this;
+  }
+  bool shouldDumpPassPipeline() const { return dumpPassPipelineFlag; }
+
+  /// Set the output format to bytecode instead of textual IR.
+  MlirOptMainConfig &emitBytecode(bool emit) {
+    emitBytecodeFlag = emit;
+    return *this;
+  }
+  bool shouldEmitBytecode() const { return emitBytecodeFlag; }
+
+  /// Set the callback to populate the pass manager (replaces any parser).
+  MlirOptMainConfig &
+  setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
+    passPipelineCallback = std::move(callback);
+    return *this;
+  }
+
+  /// Populate the pass manager from `parser`, which must outlive this config.
+  MlirOptMainConfig &setPassPipelineParser(const PassPipelineCLParser &parser);
+
+  /// Populate the pass manager via the callback, if any; succeed otherwise.
+  LogicalResult setupPassPipeline(PassManager &pm) const {
+    if (passPipelineCallback)
+      return passPipelineCallback(pm);
+    return success();
+  }
+
+  /// Deprecated: load all available dialects into the context upfront.
+  MlirOptMainConfig &preloadDialectsInContext(bool preload) {
+    preloadDialectsInContextFlag = preload;
+    return *this;
+  }
+  bool shouldPreloadDialectsInContext() const {
+    return preloadDialectsInContextFlag;
+  }
+
+  /// Show the registered dialects before trying to load the input file.
+  MlirOptMainConfig &showDialects(bool show) {
+    showDialectsFlag = show;
+    return *this;
+  }
+  bool shouldShowDialects() const { return showDialectsFlag; }
+
+  /// Set whether to split the input file based on the `// -----` marker into
+  /// pieces and process each chunk independently. `split` defaults to true.
+  MlirOptMainConfig &splitInputFile(bool split = true) {
+    splitInputFileFlag = split;
+    return *this;
+  }
+  bool shouldSplitInputFile() const { return splitInputFileFlag; }
+
+  /// Set whether to implicitly add a top-level module op during parsing.
+  MlirOptMainConfig &useImplicitModule(bool useImplicitModule) {
+    useImplicitModuleFlag = useImplicitModule;
+    return *this;
+  }
+  bool shouldUseImplicitModule() const { return useImplicitModuleFlag; }
+
+  /// Set whether to check that emitted diagnostics match `expected-*` lines on
+  /// the corresponding line. This is meant for implementing diagnostic tests.
+  MlirOptMainConfig &verifyDiagnostics(bool verify) {
+    verifyDiagnosticsFlag = verify;
+    return *this;
+  }
+  bool shouldVerifyDiagnostics() const { return verifyDiagnosticsFlag; }
+
+  /// Set whether to run the verifier after each transformation pass.
+  MlirOptMainConfig &verifyPasses(bool verify) {
+    verifyPassesFlag = verify;
+    return *this;
+  }
+  bool shouldVerifyPasses() const { return verifyPassesFlag; }
+
+private:
+  /// Allow operations from dialects not registered in the MLIRContext.
+  /// This option is for convenience during testing only and discouraged in
+  /// general.
+  bool allowUnregisteredDialectsFlag = false;
+
+  /// Print the pipeline that will be run.
+  bool dumpPassPipelineFlag = false;
+
+  /// Emit bytecode instead of textual assembly when generating output.
+  bool emitBytecodeFlag = false;
+
+  /// The callback to populate the pass manager; may be empty.
+  std::function<LogicalResult(PassManager &)> passPipelineCallback;
+
+  /// Deprecated: load all available dialects into the context upfront.
+  bool preloadDialectsInContextFlag = false;
+
+  /// Show the registered dialects before trying to load the input file.
+  bool showDialectsFlag = false;
+
+  /// Split the input file based on the `// -----` marker into pieces and
+  /// process each chunk independently.
+  bool splitInputFileFlag = false;
+
+  /// Use an implicit top-level module op during parsing.
+  bool useImplicitModuleFlag = true;
+
+  /// Check that emitted diagnostics match `expected-*` lines on the
+  /// corresponding line. This is meant for implementing diagnostic tests.
+  bool verifyDiagnosticsFlag = false;
+
+  /// Run the verifier after each transformation pass.
+  bool verifyPassesFlag = true;
+};
+
 /// This defines the function type used to setup the pass manager. This can be
 /// used to pass in a callback to setup a default pass pipeline to be applied on
 /// the loaded IR.
 using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
 
-/// Perform the core processing behind `mlir-opt`:
+/// Perform the core processing behind `mlir-opt`.
 /// - outputStream is the stream where the resulting IR is printed.
 /// - buffer is the in-memory file to parser and process.
-/// - passPipeline is the specification of the pipeline that will be applied.
 /// - registry should contain all the dialects that can be parsed in the source.
-/// - splitInputFile will look for a "-----" marker in the input file, and load
-/// each chunk in an individual ModuleOp processed separately.
-/// - verifyDiagnostics enables a verification mode where comments starting with
-/// "expected-(error|note|remark|warning)" are parsed in the input and matched
-/// against emitted diagnostics.
-/// - verifyPasses enables the IR verifier in-between each pass in the pipeline.
-/// - allowUnregisteredDialects allows to parse and create operation without
-/// registering the Dialect in the MLIRContext.
-/// - preloadDialectsInContext will trigger the upfront loading of all
-///   dialects from the global registry in the MLIRContext. This option is
-///   deprecated and will be removed soon.
-/// - emitBytecode will generate bytecode output instead of text.
-/// - implicitModule will enable implicit addition of a top-level
-/// 'builtin.module' if one doesn't already exist.
-/// - dumpPassPipeline will dump the pipeline being run to stderr
+/// - config contains the configuration options for the tool.
+LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
+                          std::unique_ptr<llvm::MemoryBuffer> buffer,
+                          DialectRegistry &registry,
+                          const MlirOptMainConfig &config);
+
+/// Perform the core processing behind `mlir-opt`.
+/// This API is deprecated, use the MlirOptMainConfig version above instead.
 LogicalResult
 MlirOptMain(llvm::raw_ostream &outputStream,
             std::unique_ptr<llvm::MemoryBuffer> buffer,
@@ -63,9 +186,8 @@ MlirOptMain(llvm::raw_ostream &outputStream,
             bool preloadDialectsInContext = false, bool emitBytecode = false,
             bool implicitModule = false, bool dumpPassPipeline = false);
 
-/// Support a callback to setup the pass manager.
-/// - passManagerSetupFn is the callback invoked to setup the pass manager to
-///   apply on the loaded IR.
+/// Perform the core processing behind `mlir-opt`.
+/// This API is deprecated, use the MlirOptMainConfig version above instead.
 LogicalResult MlirOptMain(
     llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
     PassPipelineFn passManagerSetupFn, DialectRegistry &registry,

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 60b5a42eb0eb1..75eefb19991ab 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -40,6 +40,25 @@
 using namespace mlir;
 using namespace llvm;
 
+/// The callback references `passPipeline` and `this`; both must outlive it.
+MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
+    const PassPipelineCLParser &passPipeline) {
+  passPipelineCallback = [&](PassManager &pm) {
+    auto errorHandler = [&](const Twine &msg) {
+      emitError(UnknownLoc::get(pm.getContext())) << msg;
+      return failure();
+    };
+    if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+      return failure();
+    if (this->shouldDumpPassPipeline()) {
+      pm.dump();
+      llvm::errs() << "\n";
+    }
+    return success();
+  };
+  return *this;
+}
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -47,10 +66,9 @@ using namespace llvm;
 /// passes, then prints the output.
 ///
 static LogicalResult
-performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
+performActions(raw_ostream &os,
                const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
-               MLIRContext *context, PassPipelineFn passManagerSetupFn,
-               bool emitBytecode, bool implicitModule) {
+               MLIRContext *context, const MlirOptMainConfig &config) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -66,13 +84,14 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
   // untouched.
   PassReproducerOptions reproOptions;
   FallbackAsmResourceMap fallbackResourceMap;
-  ParserConfig config(context, /*verifyAfterParse=*/true, &fallbackResourceMap);
-  reproOptions.attachResourceParser(config);
+  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
+                           &fallbackResourceMap);
+  reproOptions.attachResourceParser(parseConfig);
 
   // Parse the input file and reset the context threading state.
   TimingScope parserTiming = timing.nest("Parser");
-  OwningOpRef<Operation *> op =
-      parseSourceFileForTool(sourceMgr, config, implicitModule);
+  OwningOpRef<Operation *> op = parseSourceFileForTool(
+      sourceMgr, parseConfig, config.shouldUseImplicitModule());
   context->enableMultithreading(wasThreadingEnabled);
   if (!op)
     return failure();
@@ -80,10 +99,10 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
 
   // Prepare the pass manager, applying command-line and reproducer options.
   PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
-  pm.enableVerifier(verifyPasses);
+  pm.enableVerifier(config.shouldVerifyPasses());
   applyPassManagerCLOptions(pm);
   pm.enableTiming(timing);
-  if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm)))
+  if (failed(reproOptions.apply(pm)) || failed(config.setupPassPipeline(pm)))
     return failure();
 
   // Run the pipeline.
@@ -92,7 +111,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
 
   // Print the output.
   TimingScope outputTiming = timing.nest("Output");
-  if (emitBytecode) {
+  if (config.shouldEmitBytecode()) {
     BytecodeWriterConfig writerConfig(fallbackResourceMap);
     writeBytecodeToFile(op.get(), os, writerConfig);
   } else {
@@ -106,13 +125,11 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
 
 /// Parses the memory buffer.  If successfully, run a series of passes against
 /// it and print the result.
-static LogicalResult
-processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
-              bool verifyDiagnostics, bool verifyPasses,
-              bool allowUnregisteredDialects, bool preloadDialectsInContext,
-              bool emitBytecode, bool implicitModule,
-              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
-              llvm::ThreadPool *threadPool) {
+static LogicalResult processBuffer(raw_ostream &os,
+                                   std::unique_ptr<MemoryBuffer> ownedBuffer,
+                                   const MlirOptMainConfig &config,
+                                   DialectRegistry &registry,
+                                   llvm::ThreadPool *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   auto sourceMgr = std::make_shared<SourceMgr>();
   sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -124,20 +141,18 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
     context.setThreadPool(*threadPool);
 
   // Parse the input file.
-  if (preloadDialectsInContext)
+  if (config.shouldPreloadDialectsInContext())
     context.loadAllAvailableDialects();
-  context.allowUnregisteredDialects(allowUnregisteredDialects);
-  if (verifyDiagnostics)
+  context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
+  if (config.shouldVerifyDiagnostics())
     context.printOpOnDiagnostic(false);
   context.getDebugActionManager().registerActionHandler<DebugCounter>();
 
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.
-  if (!verifyDiagnostics) {
+  if (!config.shouldVerifyDiagnostics()) {
     SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
-    return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passManagerSetupFn, emitBytecode,
-                          implicitModule);
+    return performActions(os, sourceMgr, &context, config);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
@@ -145,22 +160,17 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
   // Do any processing requested by command line flags.  We don't care whether
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
-  (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                       passManagerSetupFn, emitBytecode, implicitModule);
+  (void)performActions(os, sourceMgr, &context, config);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
   return sourceMgrHandler.verify();
 }
 
-LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
-                                std::unique_ptr<MemoryBuffer> buffer,
-                                PassPipelineFn passManagerSetupFn,
-                                DialectRegistry &registry, bool splitInputFile,
-                                bool verifyDiagnostics, bool verifyPasses,
-                                bool allowUnregisteredDialects,
-                                bool preloadDialectsInContext,
-                                bool emitBytecode, bool implicitModule) {
+LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
+                                std::unique_ptr<llvm::MemoryBuffer> buffer,
+                                DialectRegistry &registry,
+                                const MlirOptMainConfig &config) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
@@ -177,13 +187,32 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
 
   auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
                      raw_ostream &os) {
-    return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
-                         verifyPasses, allowUnregisteredDialects,
-                         preloadDialectsInContext, emitBytecode, implicitModule,
-                         passManagerSetupFn, registry, threadPool);
+    return processBuffer(os, std::move(chunkBuffer), config, registry,
+                         threadPool);
   };
   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
-                               splitInputFile, /*insertMarkerInOutput=*/true);
+                               config.shouldSplitInputFile(),
+                               /*insertMarkerInOutput=*/true);
+}
+
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                PassPipelineFn passManagerSetupFn,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext,
+                                bool emitBytecode, bool implicitModule) {
+  MlirOptMainConfig config;
+  config.allowUnregisteredDialects(allowUnregisteredDialects)
+      .emitBytecode(emitBytecode)
+      .preloadDialectsInContext(preloadDialectsInContext)
+      .setPassPipelineSetupFn(passManagerSetupFn)
+      .splitInputFile(splitInputFile)
+      .useImplicitModule(implicitModule)
+      .verifyDiagnostics(verifyDiagnostics)
+      .verifyPasses(verifyPasses);
+  return MlirOptMain(outputStream, std::move(buffer), registry, config);
 }
 
 LogicalResult mlir::MlirOptMain(
@@ -192,23 +221,17 @@ LogicalResult mlir::MlirOptMain(
     bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
     bool allowUnregisteredDialects, bool preloadDialectsInContext,
     bool emitBytecode, bool implicitModule, bool dumpPassPipeline) {
-  auto passManagerSetupFn = [&](PassManager &pm) {
-    auto errorHandler = [&](const Twine &msg) {
-      emitError(UnknownLoc::get(pm.getContext())) << msg;
-      return failure();
-    };
-    if (failed(passPipeline.addToPipeline(pm, errorHandler)))
-      return failure();
-    if (dumpPassPipeline) {
-      pm.dump();
-      llvm::errs() << "\n";
-    }
-    return success();
-  };
-  return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
-                     registry, splitInputFile, verifyDiagnostics, verifyPasses,
-                     allowUnregisteredDialects, preloadDialectsInContext,
-                     emitBytecode, implicitModule);
+  return MlirOptMain(outputStream, std::move(buffer), registry,
+                     MlirOptMainConfig{}
+                         .splitInputFile(splitInputFile)
+                         .verifyDiagnostics(verifyDiagnostics)
+                         .verifyPasses(verifyPasses)
+                         .allowUnregisteredDialects(allowUnregisteredDialects)
+                         .preloadDialectsInContext(preloadDialectsInContext)
+                         .emitBytecode(emitBytecode)
+                         .useImplicitModule(implicitModule)
+                         .dumpPassPipeline(dumpPassPipeline)
+                         .setPassPipelineParser(passPipeline));
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -301,12 +324,19 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
     llvm::errs() << errorMessage << "\n";
     return failure();
   }
-
-  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
-                         splitInputFile, verifyDiagnostics, verifyPasses,
-                         allowUnregisteredDialects, preloadDialectsInContext,
-                         emitBytecode, /*implicitModule=*/!noImplicitModule,
-                         dumpPassPipeline)))
+  // Setup the configuration for the main function.
+  MlirOptMainConfig config;
+  config.setPassPipelineParser(passPipeline)
+      .splitInputFile(splitInputFile)
+      .verifyDiagnostics(verifyDiagnostics)
+      .verifyPasses(verifyPasses)
+      .allowUnregisteredDialects(allowUnregisteredDialects)
+      .preloadDialectsInContext(preloadDialectsInContext)
+      .emitBytecode(emitBytecode)
+      .useImplicitModule(!noImplicitModule)
+      .dumpPassPipeline(dumpPassPipeline);
+
+  if (failed(MlirOptMain(output->os(), std::move(file), registry, config)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.


        


More information about the Mlir-commits mailing list