[Mlir-commits] [mlir] 8d2e0c1 - Refactor a MlirOptMainConfig class to hold the configuration of MlirOptMain (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Sat Feb 25 14:04:37 PST 2023
Author: Mehdi Amini
Date: 2023-02-25T17:04:29-05:00
New Revision: 8d2e0c16e691797f11a713a5601dc01e50d8734a
URL: https://github.com/llvm/llvm-project/commit/8d2e0c16e691797f11a713a5601dc01e50d8734a
DIFF: https://github.com/llvm/llvm-project/commit/8d2e0c16e691797f11a713a5601dc01e50d8734a.diff
LOG: Refactor a MlirOptMainConfig class to hold the configuration of MlirOptMain (NFC)
The list of boolean flags and other options is becoming unreasonably long.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D143829
Added:
Modified:
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 34f506273eff3..70cb2957e7182 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -17,6 +17,7 @@
#include "llvm/ADT/StringRef.h"
#include <cstdlib>
+#include <functional>
#include <memory>
namespace llvm {
@@ -29,31 +30,153 @@ class DialectRegistry;
class PassPipelineCLParser;
class PassManager;
+/// Configuration options for the mlir-opt tool.
+/// This is intended to help build tools like mlir-opt by collecting the
+/// supported options in one place.
+/// The API is fluent: every setter returns a reference to this object so that
+/// calls can be chained. The options are sorted in alphabetical order below.
+class MlirOptMainConfig {
+public:
+  /// Allow operations from dialects that are not registered in the context.
+  /// This option is for convenience during testing only and discouraged in
+  /// general. Defaults to false.
+  MlirOptMainConfig &allowUnregisteredDialects(bool allow) {
+    allowUnregisteredDialectsFlag = allow;
+    return *this;
+  }
+  bool shouldAllowUnregisteredDialects() const {
+    return allowUnregisteredDialectsFlag;
+  }
+
+  /// Print the pass pipeline as text before executing it. Defaults to false.
+  /// Note: this flag is only consulted by the callback installed through
+  /// setPassPipelineParser(); a callback installed through
+  /// setPassPipelineSetupFn() is responsible for any dumping itself.
+  MlirOptMainConfig &dumpPassPipeline(bool dump) {
+    dumpPassPipelineFlag = dump;
+    return *this;
+  }
+  bool shouldDumpPassPipeline() const { return dumpPassPipelineFlag; }
+
+  /// Set the output format to bytecode instead of textual IR. Defaults to
+  /// false.
+  MlirOptMainConfig &emitBytecode(bool emit) {
+    emitBytecodeFlag = emit;
+    return *this;
+  }
+  bool shouldEmitBytecode() const { return emitBytecodeFlag; }
+
+  /// Set the callback used to populate the pass manager. This replaces any
+  /// previously installed callback, including one installed by
+  /// setPassPipelineParser().
+  MlirOptMainConfig &
+  setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
+    passPipelineCallback = std::move(callback);
+    return *this;
+  }
+
+  /// Populate the pass manager from a command-line pipeline parser. This
+  /// replaces any previously installed callback.
+  /// The installed callback keeps a reference to `parser` and to this config
+  /// object (it reads the dump-pass-pipeline flag when invoked). Both must
+  /// outlive every use of the callback. Copying the config afterwards does not
+  /// rebind the callback to the copy.
+  MlirOptMainConfig &setPassPipelineParser(const PassPipelineCLParser &parser);
+
+  /// Populate the pass manager using the configured callback, if any.
+  /// Succeeds without touching `pm` when no callback was set.
+  LogicalResult setupPassPipeline(PassManager &pm) const {
+    if (passPipelineCallback)
+      return passPipelineCallback(pm);
+    return success();
+  }
+
+  /// Deprecated: load all available dialects into the context up-front,
+  /// before parsing. Defaults to false.
+  MlirOptMainConfig &preloadDialectsInContext(bool preload) {
+    preloadDialectsInContextFlag = preload;
+    return *this;
+  }
+  bool shouldPreloadDialectsInContext() const {
+    return preloadDialectsInContextFlag;
+  }
+
+  /// Show the registered dialects before trying to load the input file.
+  /// Defaults to false.
+  MlirOptMainConfig &showDialects(bool show) {
+    showDialectsFlag = show;
+    return *this;
+  }
+  bool shouldShowDialects() const { return showDialectsFlag; }
+
+  /// Set whether to split the input file based on the `// -----` marker into
+  /// pieces and process each chunk independently. Defaults to false; calling
+  /// this without an argument enables splitting.
+  MlirOptMainConfig &splitInputFile(bool split = true) {
+    splitInputFileFlag = split;
+    return *this;
+  }
+  bool shouldSplitInputFile() const { return splitInputFileFlag; }
+
+  /// Set whether a top-level module op is implicitly added during parsing
+  /// when the input does not already provide one. Defaults to true.
+  MlirOptMainConfig &useImplicitModule(bool useImplicitModule) {
+    useImplicitModuleFlag = useImplicitModule;
+    return *this;
+  }
+  bool shouldUseImplicitModule() const { return useImplicitModuleFlag; }
+
+  /// Set whether to check that emitted diagnostics match `expected-*` lines on
+  /// the corresponding line. This is meant for implementing diagnostic tests.
+  /// Defaults to false.
+  MlirOptMainConfig &verifyDiagnostics(bool verify) {
+    verifyDiagnosticsFlag = verify;
+    return *this;
+  }
+  bool shouldVerifyDiagnostics() const { return verifyDiagnosticsFlag; }
+
+  /// Set whether to run the verifier after each transformation pass.
+  /// Defaults to true.
+  MlirOptMainConfig &verifyPasses(bool verify) {
+    verifyPassesFlag = verify;
+    return *this;
+  }
+  bool shouldVerifyPasses() const { return verifyPassesFlag; }
+
+private:
+  /// Allow operations from dialects that are not registered in the context.
+  /// This option is for convenience during testing only and discouraged in
+  /// general.
+  bool allowUnregisteredDialectsFlag = false;
+
+  /// Print the pipeline that will be run.
+  bool dumpPassPipelineFlag = false;
+
+  /// Emit bytecode instead of textual assembly when generating output.
+  bool emitBytecodeFlag = false;
+
+  /// The callback to populate the pass manager.
+  std::function<LogicalResult(PassManager &)> passPipelineCallback;
+
+  /// Load all available dialects up-front (deprecated).
+  bool preloadDialectsInContextFlag = false;
+
+  /// Show the registered dialects before trying to load the input file.
+  bool showDialectsFlag = false;
+
+  /// Split the input file based on the `// -----` marker into pieces and
+  /// process each chunk independently.
+  bool splitInputFileFlag = false;
+
+  /// Use an implicit top-level module op during parsing.
+  bool useImplicitModuleFlag = true;
+
+  /// Check that emitted diagnostics match `expected-*` lines on the
+  /// corresponding line. This is meant for implementing diagnostic tests.
+  bool verifyDiagnosticsFlag = false;
+
+  /// Run the verifier after each transformation pass.
+  bool verifyPassesFlag = true;
+};
+
/// This defines the function type used to setup the pass manager. This can be
/// used to pass in a callback to setup a default pass pipeline to be applied on
/// the loaded IR.
using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
-/// Perform the core processing behind `mlir-opt`:
+/// Perform the core processing behind `mlir-opt`.
/// - outputStream is the stream where the resulting IR is printed.
/// - buffer is the in-memory file to parser and process.
-/// - passPipeline is the specification of the pipeline that will be applied.
/// - registry should contain all the dialects that can be parsed in the source.
-/// - splitInputFile will look for a "-----" marker in the input file, and load
-/// each chunk in an individual ModuleOp processed separately.
-/// - verifyDiagnostics enables a verification mode where comments starting with
-/// "expected-(error|note|remark|warning)" are parsed in the input and matched
-/// against emitted diagnostics.
-/// - verifyPasses enables the IR verifier in-between each pass in the pipeline.
-/// - allowUnregisteredDialects allows to parse and create operation without
-/// registering the Dialect in the MLIRContext.
-/// - preloadDialectsInContext will trigger the upfront loading of all
-/// dialects from the global registry in the MLIRContext. This option is
-/// deprecated and will be removed soon.
-/// - emitBytecode will generate bytecode output instead of text.
-/// - implicitModule will enable implicit addition of a top-level
-/// 'builtin.module' if one doesn't already exist.
-/// - dumpPassPipeline will dump the pipeline being run to stderr
+/// - config contains the configuration options for the tool (see
+///   MlirOptMainConfig).
+LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
+                          std::unique_ptr<llvm::MemoryBuffer> buffer,
+                          DialectRegistry &registry,
+                          const MlirOptMainConfig &config);
+
+/// Perform the core processing behind `mlir-opt`.
+/// This API is deprecated, use the MlirOptMainConfig version above instead.
LogicalResult
MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
@@ -63,9 +186,8 @@ MlirOptMain(llvm::raw_ostream &outputStream,
bool preloadDialectsInContext = false, bool emitBytecode = false,
bool implicitModule = false, bool dumpPassPipeline = false);
-/// Support a callback to setup the pass manager.
-/// - passManagerSetupFn is the callback invoked to setup the pass manager to
-/// apply on the loaded IR.
+/// Perform the core processing behind `mlir-opt`.
+/// This API is deprecated, use the MlirOptMainConfig version above instead.
LogicalResult MlirOptMain(
llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
PassPipelineFn passManagerSetupFn, DialectRegistry ®istry,
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 60b5a42eb0eb1..75eefb19991ab 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -40,6 +40,25 @@
using namespace mlir;
using namespace llvm;
+/// Install a pass-pipeline callback driven by a command-line pipeline parser.
+/// The callback captures `passPipeline` and `this` by reference. Both must
+/// outlive the last use of this config, and copies of the config keep
+/// referring to the original object. The dump flag is read when the callback
+/// runs (not when it is installed), so setters chained after this call are
+/// still honored.
+MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
+    const PassPipelineCLParser &passPipeline) {
+  passPipelineCallback = [this, &passPipeline](PassManager &pm) {
+    // Report pipeline parsing errors as diagnostics at an unknown location.
+    auto errorHandler = [&](const Twine &msg) {
+      emitError(UnknownLoc::get(pm.getContext())) << msg;
+      return failure();
+    };
+    if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+      return failure();
+    if (shouldDumpPassPipeline()) {
+      pm.dump();
+      llvm::errs() << "\n";
+    }
+    return success();
+  };
+  return *this;
+}
+
/// Perform the actions on the input file indicated by the command line flags
/// within the specified context.
///
@@ -47,10 +66,9 @@ using namespace llvm;
/// passes, then prints the output.
///
static LogicalResult
-performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
+performActions(raw_ostream &os,
const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
- MLIRContext *context, PassPipelineFn passManagerSetupFn,
- bool emitBytecode, bool implicitModule) {
+ MLIRContext *context, const MlirOptMainConfig &config) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -66,13 +84,14 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig config(context, /*verifyAfterParse=*/true, &fallbackResourceMap);
- reproOptions.attachResourceParser(config);
+ ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
+ &fallbackResourceMap);
+ reproOptions.attachResourceParser(parseConfig);
// Parse the input file and reset the context threading state.
TimingScope parserTiming = timing.nest("Parser");
- OwningOpRef<Operation *> op =
- parseSourceFileForTool(sourceMgr, config, implicitModule);
+ OwningOpRef<Operation *> op = parseSourceFileForTool(
+ sourceMgr, parseConfig, config.shouldUseImplicitModule());
context->enableMultithreading(wasThreadingEnabled);
if (!op)
return failure();
@@ -80,10 +99,10 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
// Prepare the pass manager, applying command-line and reproducer options.
PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
- pm.enableVerifier(verifyPasses);
+ pm.enableVerifier(config.shouldVerifyPasses());
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);
- if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm)))
+ if (failed(reproOptions.apply(pm)) || failed(config.setupPassPipeline(pm)))
return failure();
// Run the pipeline.
@@ -92,7 +111,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
// Print the output.
TimingScope outputTiming = timing.nest("Output");
- if (emitBytecode) {
+ if (config.shouldEmitBytecode()) {
BytecodeWriterConfig writerConfig(fallbackResourceMap);
writeBytecodeToFile(op.get(), os, writerConfig);
} else {
@@ -106,13 +125,11 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
-static LogicalResult
-processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
- bool verifyDiagnostics, bool verifyPasses,
- bool allowUnregisteredDialects, bool preloadDialectsInContext,
- bool emitBytecode, bool implicitModule,
- PassPipelineFn passManagerSetupFn, DialectRegistry ®istry,
- llvm::ThreadPool *threadPool) {
+static LogicalResult processBuffer(raw_ostream &os,
+ std::unique_ptr<MemoryBuffer> ownedBuffer,
+ const MlirOptMainConfig &config,
+ DialectRegistry ®istry,
+ llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -124,20 +141,18 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
context.setThreadPool(*threadPool);
// Parse the input file.
- if (preloadDialectsInContext)
+ if (config.shouldPreloadDialectsInContext())
context.loadAllAvailableDialects();
- context.allowUnregisteredDialects(allowUnregisteredDialects);
- if (verifyDiagnostics)
+ context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
+ if (config.shouldVerifyDiagnostics())
context.printOpOnDiagnostic(false);
context.getDebugActionManager().registerActionHandler<DebugCounter>();
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
- if (!verifyDiagnostics) {
+ if (!config.shouldVerifyDiagnostics()) {
SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
- return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passManagerSetupFn, emitBytecode,
- implicitModule);
+ return performActions(os, sourceMgr, &context, config);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
@@ -145,22 +160,17 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
- (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passManagerSetupFn, emitBytecode, implicitModule);
+ (void)performActions(os, sourceMgr, &context, config);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
return sourceMgrHandler.verify();
}
-LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
- std::unique_ptr<MemoryBuffer> buffer,
- PassPipelineFn passManagerSetupFn,
- DialectRegistry ®istry, bool splitInputFile,
- bool verifyDiagnostics, bool verifyPasses,
- bool allowUnregisteredDialects,
- bool preloadDialectsInContext,
- bool emitBytecode, bool implicitModule) {
+LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
+ std::unique_ptr<llvm::MemoryBuffer> buffer,
+ DialectRegistry ®istry,
+ const MlirOptMainConfig &config) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
@@ -177,13 +187,32 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
- verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, emitBytecode, implicitModule,
- passManagerSetupFn, registry, threadPool);
+ return processBuffer(os, std::move(chunkBuffer), config, registry,
+ threadPool);
};
return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
- splitInputFile, /*insertMarkerInOutput=*/true);
+ config.shouldSplitInputFile(),
+ /*insertMarkerInOutput=*/true);
+}
+
+/// Deprecated overload taking individual flags and a pass-manager setup
+/// callback; forwards to the MlirOptMainConfig-based entry point.
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                PassPipelineFn passManagerSetupFn,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext,
+                                bool emitBytecode, bool implicitModule) {
+  // The config is a temporary that only lives for this call. This matters
+  // because the stored callback wraps a function_ref, which merely borrows the
+  // caller's callable.
+  return MlirOptMain(outputStream, std::move(buffer), registry,
+                     MlirOptMainConfig{}
+                         .splitInputFile(splitInputFile)
+                         .verifyDiagnostics(verifyDiagnostics)
+                         .verifyPasses(verifyPasses)
+                         .allowUnregisteredDialects(allowUnregisteredDialects)
+                         .preloadDialectsInContext(preloadDialectsInContext)
+                         .emitBytecode(emitBytecode)
+                         .useImplicitModule(implicitModule)
+                         .setPassPipelineSetupFn(passManagerSetupFn));
+}
LogicalResult mlir::MlirOptMain(
@@ -192,23 +221,17 @@ LogicalResult mlir::MlirOptMain(
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule, bool dumpPassPipeline) {
- auto passManagerSetupFn = [&](PassManager &pm) {
- auto errorHandler = [&](const Twine &msg) {
- emitError(UnknownLoc::get(pm.getContext())) << msg;
- return failure();
- };
- if (failed(passPipeline.addToPipeline(pm, errorHandler)))
- return failure();
- if (dumpPassPipeline) {
- pm.dump();
- llvm::errs() << "\n";
- }
- return success();
- };
- return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
- registry, splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext,
- emitBytecode, implicitModule);
+ return MlirOptMain(outputStream, std::move(buffer), registry,
+ MlirOptMainConfig{}
+ .splitInputFile(splitInputFile)
+ .verifyDiagnostics(verifyDiagnostics)
+ .verifyPasses(verifyPasses)
+ .allowUnregisteredDialects(allowUnregisteredDialects)
+ .preloadDialectsInContext(preloadDialectsInContext)
+ .emitBytecode(emitBytecode)
+ .useImplicitModule(implicitModule)
+ .dumpPassPipeline(dumpPassPipeline)
+ .setPassPipelineParser(passPipeline));
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -301,12 +324,19 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
llvm::errs() << errorMessage << "\n";
return failure();
}
-
- if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
- splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext,
- emitBytecode, /*implicitModule=*/!noImplicitModule,
- dumpPassPipeline)))
+ // Setup the configuration for the main function.
+ MlirOptMainConfig config;
+ config.setPassPipelineParser(passPipeline)
+ .splitInputFile(splitInputFile)
+ .verifyDiagnostics(verifyDiagnostics)
+ .verifyPasses(verifyPasses)
+ .allowUnregisteredDialects(allowUnregisteredDialects)
+ .preloadDialectsInContext(preloadDialectsInContext)
+ .emitBytecode(emitBytecode)
+ .useImplicitModule(!noImplicitModule)
+ .dumpPassPipeline(dumpPassPipeline);
+
+ if (failed(MlirOptMain(output->os(), std::move(file), registry, config)))
return failure();
// Keep the output file if the invocation of MlirOptMain was successful.
More information about the Mlir-commits
mailing list