[Mlir-commits] [mlir] 0d8df98 - [mlir] Allow for using OpPassManager in pass options
River Riddle
llvmlistbot at llvm.org
Sat Apr 2 01:07:49 PDT 2022
Author: River Riddle
Date: 2022-04-02T00:45:11-07:00
New Revision: 0d8df98035c8b35c5241e15eb8c4f42e6a008fa2
URL: https://github.com/llvm/llvm-project/commit/0d8df98035c8b35c5241e15eb8c4f42e6a008fa2
DIFF: https://github.com/llvm/llvm-project/commit/0d8df98035c8b35c5241e15eb8c4f42e6a008fa2.diff
LOG: [mlir] Allow for using OpPassManager in pass options
This significantly simplifies the boilerplate necessary for passes
to define nested pass pipelines.
Differential Revision: https://reviews.llvm.org/D122880
Added:
Modified:
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Pass/PassOptions.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/PassDetail.h
mlir/test/Pass/crash-recovery.mlir
mlir/test/Pass/pipeline-options-parsing.mlir
mlir/test/Transforms/inlining.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index f2c42bf9140fd..13b127c360f15 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -73,6 +73,9 @@ class OpPassManager {
return {begin(), end()};
}
+  /// Returns true when this pass manager currently holds no passes.
+  bool empty() const { return end() == begin(); }
+
/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(StringAttr nestedName);
@@ -110,7 +113,7 @@ class OpPassManager {
/// of pipelines.
/// Note: The quality of the string representation depends entirely on the
/// correctness of per-pass overrides of Pass::printAsTextualPipeline.
- void printAsTextualPipeline(raw_ostream &os);
+ void printAsTextualPipeline(raw_ostream &os) const;
/// Raw dump of the pass manager to llvm::errs().
void dump();
diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index 53e0060333dce..3ee91c7d463be 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -23,6 +23,8 @@
#include <memory>
namespace mlir {
+class OpPassManager;
+
namespace detail {
namespace pass_options {
/// Parse a string containing a list of comma-delimited elements, invoking the
@@ -158,7 +160,7 @@ class PassOptions : protected llvm::cl::SubCommand {
public OptionBase {
public:
template <typename... Args>
- Option(PassOptions &parent, StringRef arg, Args &&... args)
+ Option(PassOptions &parent, StringRef arg, Args &&...args)
: llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
assert(!this->isPositional() && !this->isSink() &&
@@ -319,7 +321,8 @@ class PassOptions : protected llvm::cl::SubCommand {
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
/// };
-template <typename T> class PassPipelineOptions : public detail::PassOptions {
+template <typename T>
+class PassPipelineOptions : public detail::PassOptions {
public:
/// Factory that parses the provided options and returns a unique_ptr to the
/// struct.
@@ -335,7 +338,6 @@ template <typename T> class PassPipelineOptions : public detail::PassOptions {
/// any options.
struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
  // Intentionally empty: this options struct declares no options of its own.
};
-
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -407,8 +409,92 @@ class parser<SmallVector<T, N>>
public:
parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
};
-} // end namespace cl
-} // end namespace llvm
-#endif // MLIR_PASS_PASSOPTIONS_H_
+//===----------------------------------------------------------------------===//
+// OpPassManager: OptionValue
+template <>
+struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
+  using WrapperType = mlir::OpPassManager;
+
+  OptionValue();
+  OptionValue(const mlir::OpPassManager &value);
+  OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
+  ~OptionValue();
+
+  /// Returns if the current option has a value.
+  bool hasValue() const { return value != nullptr; }
+
+  /// Returns the current value of the option. The option must have a value.
+  mlir::OpPassManager &getValue() const {
+    assert(hasValue() && "invalid option value");
+    return *value;
+  }
+
+  /// Set the value of the option, either from an existing pass manager or
+  /// from a textual pass pipeline.
+  void setValue(const mlir::OpPassManager &newValue);
+  void setValue(StringRef pipelineStr);
+
+  /// Compare the option with the provided value.
+  bool compare(const mlir::OpPassManager &rhs) const;
+  bool compare(const GenericOptionValue &rhs) const override {
+    const auto &rhsOV =
+        static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
+    // An option without a value never matches. Checking `this` as well as
+    // `rhs` keeps the typed overload from dereferencing a null pass manager.
+    if (!hasValue() || !rhsOV.hasValue())
+      return false;
+    return compare(rhsOV.getValue());
+  }
+
+private:
+  void anchor() override;
+
+  /// The underlying pass manager. We use a unique_ptr to avoid the need for the
+  /// full type definition.
+  std::unique_ptr<mlir::OpPassManager> value;
+};
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: Parser
+
+extern template class basic_parser<mlir::OpPassManager>;
+
+template <>
+class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
+public:
+ /// A utility struct used when parsing a pass manager that prevents the need
+ /// for a default constructor on OpPassManager.
+ struct ParsedPassManager {
+ ParsedPassManager();
+ ParsedPassManager(ParsedPassManager &&);
+ ~ParsedPassManager();
+ /// Implicitly view the parsed result as an OpPassManager. Asserts that a
+ /// pipeline was actually parsed into `value`.
+ operator const mlir::OpPassManager &() const {
+ assert(value && "parsed value was invalid");
+ return *value;
+ }
+
+ /// The parsed pass manager; null until a successful `parse` fills it in.
+ std::unique_ptr<mlir::OpPassManager> value;
+ };
+ /// Parsing produces a ParsedPassManager, which converts to an OpPassManager.
+ using parser_data_type = ParsedPassManager;
+ using OptVal = OptionValue<mlir::OpPassManager>;
+
+ parser(Option &opt) : basic_parser(opt) {}
+
+ /// Parse `arg` as a textual pass pipeline into `value`. Returns true on
+ /// failure and false on success.
+ bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value);
+
+ /// Print an instance of the underlying option value to the given stream.
+ static void print(raw_ostream &os, const mlir::OpPassManager &value);
+
+ // Override the name reported for values of this option type.
+ StringRef getValueName() const override { return "pass-manager"; }
+
+ /// Print the option name and the current pipeline `pm`. Also prints the
+ /// default pipeline when `defaultValue` holds one.
+ void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
+ const OptVal &defaultValue, size_t globalWidth) const;
+
+ // An out-of-line virtual method to provide a 'home' for this class.
+ void anchor() override;
+};
+
+} // namespace cl
+} // namespace llvm
+
+#endif // MLIR_PASS_PASSOPTIONS_H_
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index cda2997c5e621..4569ac27856f4 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -83,7 +83,7 @@ def Inliner : Pass<"inline"> {
let options = [
Option<"defaultPipelineStr", "default-pipeline", "std::string",
/*default=*/"", "The default optimizer pipeline used for callables">,
- ListOption<"opPipelineStrs", "op-pipelines", "std::string",
+ ListOption<"opPipelineList", "op-pipelines", "OpPassManager",
"Callable operation specific optimizer pipelines (in the form "
"of `dialect.op(pipeline)`)">,
Option<"maxInliningIterations", "max-iterations", "unsigned",
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 7256f44b6adb4..8b4310272d871 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -54,12 +54,14 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
void Pass::printAsTextualPipeline(raw_ostream &os) {
// Special case for adaptors to use the 'op_name(sub_passes)' format.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
- llvm::interleaveComma(adaptor->getPassManagers(), os,
- [&](OpPassManager &pm) {
- os << pm.getOpName() << "(";
- pm.printAsTextualPipeline(os);
- os << ")";
- });
+ llvm::interleave(
+ adaptor->getPassManagers(),
+ [&](OpPassManager &pm) {
+ os << pm.getOpName() << "(";
+ pm.printAsTextualPipeline(os);
+ os << ")";
+ },
+ [&] { os << ","; });
return;
}
// Otherwise, print the pass argument followed by its options. If the pass
@@ -295,14 +297,17 @@ OperationName OpPassManager::getOpName(MLIRContext &context) const {
/// Prints out the given passes as the textual representation of a pipeline.
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
raw_ostream &os) {
- llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
- pass->printAsTextualPipeline(os);
- });
+ // Separate passes with a bare "," instead of interleaveComma's ", ".
+ // NOTE(review): presumably this keeps printed pipelines free of spaces so
+ // they can be embedded as space-separated pass option values; confirm.
+ llvm::interleave(
+ passes,
+ [&](const std::unique_ptr<Pass> &pass) {
+ pass->printAsTextualPipeline(os);
+ },
+ [&] { os << ","; });
}
/// Prints out the passes of the pass manager as the textual representation
/// of pipelines.
-void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
+void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
+ // The leading '::' selects the file-local helper above rather than this
+ // member function.
::printAsTextualPipeline(impl->passes, os);
}
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 5088ae244c729..8a4f117451dc1 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -332,6 +332,104 @@ size_t detail::PassOptions::getOptionWidth() const {
return max;
}
+//===----------------------------------------------------------------------===//
+// MLIR Options
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: OptionValue
+
+llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
+llvm::cl::OptionValue<OpPassManager>::OptionValue(
+    const mlir::OpPassManager &value) {
+  setValue(value);
+}
+llvm::cl::OptionValue<OpPassManager> &
+llvm::cl::OptionValue<OpPassManager>::operator=(
+    const mlir::OpPassManager &rhs) {
+  setValue(rhs);
+  return *this;
+}
+
+// Spelled `~OptionValue()` rather than `~OptionValue<OpPassManager>()`:
+// a template-id is not a valid destructor name in C++20.
+llvm::cl::OptionValue<OpPassManager>::~OptionValue() = default;
+
+void llvm::cl::OptionValue<OpPassManager>::setValue(
+    const OpPassManager &newValue) {
+  // Copy-assign into the existing pass manager if there is one; otherwise
+  // allocate a new copy.
+  if (hasValue())
+    *value = newValue;
+  else
+    value = std::make_unique<mlir::OpPassManager>(newValue);
+}
+void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
+  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
+  assert(succeeded(pipeline) && "invalid pass pipeline");
+  // Never dereference a failed result, even when asserts are compiled out;
+  // keep the current value unchanged instead.
+  if (failed(pipeline))
+    return;
+  setValue(*pipeline);
+}
+
+bool llvm::cl::OptionValue<OpPassManager>::compare(
+    const mlir::OpPassManager &rhs) const {
+  // An option without a value matches nothing. This mirrors the validity
+  // check in llvm::cl::OptionValueCopy::compare. It also avoids a null
+  // dereference, e.g. when comparing against a default that was never set.
+  if (!hasValue())
+    return false;
+
+  std::string lhsStr, rhsStr;
+  {
+    raw_string_ostream lhsStream(lhsStr);
+    value->printAsTextualPipeline(lhsStream);
+
+    raw_string_ostream rhsStream(rhsStr);
+    rhs.printAsTextualPipeline(rhsStream);
+  }
+
+  // Use the textual format for pipeline comparisons.
+  return lhsStr == rhsStr;
+}
+
+void llvm::cl::OptionValue<OpPassManager>::anchor() {}
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: Parser
+
+namespace llvm {
+namespace cl {
+template class basic_parser<OpPassManager>;
+} // namespace cl
+} // namespace llvm
+
+/// Parse `arg` as a textual pass pipeline. Returns false on success (with the
+/// result stored in `value.value`) and true on failure.
+bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
+                                            ParsedPassManager &value) {
+  FailureOr<OpPassManager> parsed = parsePassPipeline(arg);
+  if (succeeded(parsed)) {
+    value.value = std::make_unique<OpPassManager>(std::move(*parsed));
+    return false;
+  }
+  return true;
+}
+
+/// Print `value` in its textual pipeline form.
+void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
+                                            const OpPassManager &value) {
+  value.printAsTextualPipeline(os);
+}
+
+/// Print the option name, the current pipeline, and the default pipeline
+/// (when one exists) to stdout.
+void llvm::cl::parser<OpPassManager>::printOptionDiff(
+    const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
+    size_t globalWidth) const {
+  raw_ostream &out = outs();
+  printOptionName(opt, globalWidth);
+  out << "= ";
+  pm.printAsTextualPipeline(out);
+
+  // Report a default only when one was actually provided.
+  if (defaultValue.hasValue()) {
+    out.indent(2) << " (default: ";
+    defaultValue.getValue().printAsTextualPipeline(out);
+    out << ")";
+  }
+  out << "\n";
+}
+
+void llvm::cl::parser<OpPassManager>::anchor() {}
+
+// Special members are defined out of line, where OpPassManager is complete.
+llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
+    default;
+llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
+    ParsedPassManager &&) = default;
+llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
+    default;
+
//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 145a113fd1a35..ae1e2aabd2381 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -585,14 +585,8 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
return;
// Update the option for the op specific optimization pipelines.
- for (auto &it : opPipelines) {
- std::string pipeline;
- llvm::raw_string_ostream pipelineOS(pipeline);
- pipelineOS << it.getKey() << "(";
- it.second.printAsTextualPipeline(pipelineOS);
- pipelineOS << ")";
- opPipelineStrs.addValue(pipeline);
- }
+ for (auto &it : opPipelines)
+ opPipelineList.addValue(it.second);
this->opPipelines.emplace_back(std::move(opPipelines));
}
@@ -751,15 +745,9 @@ LogicalResult InlinerPass::initializeOptions(StringRef options) {
// Initialize the op specific pass pipelines.
llvm::StringMap<OpPassManager> pipelines;
- for (StringRef pipeline : opPipelineStrs) {
- // Skip empty pipelines.
- if (pipeline.empty())
- continue;
- FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
- if (failed(pm))
- return failure();
- pipelines.try_emplace(pm->getOpName(), std::move(*pm));
- }
+ for (OpPassManager pipeline : opPipelineList)
+ if (!pipeline.empty())
+ pipelines.try_emplace(pipeline.getOpName(), pipeline);
opPipelines.assign({std::move(pipelines)});
return success();
diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h
index 7e1fedc136c9e..7c1f53929fe45 100644
--- a/mlir/lib/Transforms/PassDetail.h
+++ b/mlir/lib/Transforms/PassDetail.h
@@ -10,6 +10,7 @@
#define TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
diff --git a/mlir/test/Pass/crash-recovery.mlir b/mlir/test/Pass/crash-recovery.mlir
index 31db0fd92abbb..3e379dcaf6a2e 100644
--- a/mlir/test/Pass/crash-recovery.mlir
+++ b/mlir/test/Pass/crash-recovery.mlir
@@ -20,7 +20,7 @@ module @inner_mod1 {
module @foo {}
}
-// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass, test-pass-crash)'
+// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)'
// REPRO: module @inner_mod1
// REPRO: module @foo {
diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index 987ab7ad29025..74eea1a210188 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -14,4 +14,4 @@
// CHECK_1: test-options-pass{list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
-// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }), func.func(test-options-pass{list=1,2,3,4 string= }))
+// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }),func.func(test-options-pass{list=1,2,3,4 string= }))
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 4d762edd58441..883fde94a1c61 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -2,6 +2,7 @@
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
// RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
+// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY
// Inline a function that takes an argument.
func @func_with_arg(%c : i32) -> i32 {
More information about the Mlir-commits
mailing list