[Mlir-commits] [mlir] 0d8df98 - [mlir] Allow for using OpPassManager in pass options

River Riddle llvmlistbot at llvm.org
Sat Apr 2 01:07:49 PDT 2022


Author: River Riddle
Date: 2022-04-02T00:45:11-07:00
New Revision: 0d8df98035c8b35c5241e15eb8c4f42e6a008fa2

URL: https://github.com/llvm/llvm-project/commit/0d8df98035c8b35c5241e15eb8c4f42e6a008fa2
DIFF: https://github.com/llvm/llvm-project/commit/0d8df98035c8b35c5241e15eb8c4f42e6a008fa2.diff

LOG: [mlir] Allow for using OpPassManager in pass options

This significantly simplifies the boilerplate necessary for passes
to define nested pass pipelines.

Differential Revision: https://reviews.llvm.org/D122880

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/include/mlir/Pass/PassOptions.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassRegistry.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/lib/Transforms/PassDetail.h
    mlir/test/Pass/crash-recovery.mlir
    mlir/test/Pass/pipeline-options-parsing.mlir
    mlir/test/Transforms/inlining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index f2c42bf9140fd..13b127c360f15 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -73,6 +73,9 @@ class OpPassManager {
     return {begin(), end()};
   }
 
+  /// Returns true if the pass manager has no passes.
+  bool empty() const { return begin() == end(); }
+
   /// Nest a new operation pass manager for the given operation kind under this
   /// pass manager.
   OpPassManager &nest(StringAttr nestedName);
@@ -110,7 +113,7 @@ class OpPassManager {
   /// of pipelines.
   /// Note: The quality of the string representation depends entirely on the
   /// the correctness of per-pass overrides of Pass::printAsTextualPipeline.
-  void printAsTextualPipeline(raw_ostream &os);
+  void printAsTextualPipeline(raw_ostream &os) const;
 
   /// Raw dump of the pass manager to llvm::errs().
   void dump();

diff  --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index 53e0060333dce..3ee91c7d463be 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -23,6 +23,8 @@
 #include <memory>
 
 namespace mlir {
+class OpPassManager;
+
 namespace detail {
 namespace pass_options {
 /// Parse a string containing a list of comma-delimited elements, invoking the
@@ -158,7 +160,7 @@ class PassOptions : protected llvm::cl::SubCommand {
         public OptionBase {
   public:
     template <typename... Args>
-    Option(PassOptions &parent, StringRef arg, Args &&... args)
+    Option(PassOptions &parent, StringRef arg, Args &&...args)
         : llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
               arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
       assert(!this->isPositional() && !this->isSink() &&
@@ -319,7 +321,8 @@ class PassOptions : protected llvm::cl::SubCommand {
 /// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
 ///   ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
 /// };
-template <typename T> class PassPipelineOptions : public detail::PassOptions {
+template <typename T>
+class PassPipelineOptions : public detail::PassOptions {
 public:
   /// Factory that parses the provided options and returns a unique_ptr to the
   /// struct.
@@ -335,7 +338,6 @@ template <typename T> class PassPipelineOptions : public detail::PassOptions {
 /// any options.
 struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
 };
-
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -407,8 +409,92 @@ class parser<SmallVector<T, N>>
 public:
   parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
 };
-} // end namespace cl
-} // end namespace llvm
 
-#endif // MLIR_PASS_PASSOPTIONS_H_
+//===----------------------------------------------------------------------===//
+// OpPassManager: OptionValue
 
+template <>
+struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
+  using WrapperType = mlir::OpPassManager;
+
+  OptionValue();
+  OptionValue(const mlir::OpPassManager &value);
+  OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
+  ~OptionValue();
+
+  /// Returns true if the current option has a value.
+  bool hasValue() const { return value.get(); }
+
+  /// Returns the current value of the option.
+  mlir::OpPassManager &getValue() const {
+    assert(hasValue() && "invalid option value");
+    return *value;
+  }
+
+  /// Set the value of the option.
+  void setValue(const mlir::OpPassManager &newValue);
+  void setValue(StringRef pipelineStr);
+
+  /// Compare the option with the provided value.
+  bool compare(const mlir::OpPassManager &rhs) const;
+  bool compare(const GenericOptionValue &rhs) const override {
+    const auto &rhsOV =
+        static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
+    if (!hasValue() || !rhsOV.hasValue())
+      return false;
+    return compare(rhsOV.getValue());
+  }
+
+private:
+  void anchor() override;
+
+  /// The underlying pass manager. We use a unique_ptr to avoid the need for the
+  /// full type definition.
+  std::unique_ptr<mlir::OpPassManager> value;
+};
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: Parser
+
+extern template class basic_parser<mlir::OpPassManager>;
+
+template <>
+class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
+public:
+  /// A utility struct used when parsing a pass manager that prevents the need
+  /// for a default constructor on OpPassManager.
+  struct ParsedPassManager {
+    ParsedPassManager();
+    ParsedPassManager(ParsedPassManager &&);
+    ~ParsedPassManager();
+    operator const mlir::OpPassManager &() const {
+      assert(value && "parsed value was invalid");
+      return *value;
+    }
+
+    std::unique_ptr<mlir::OpPassManager> value;
+  };
+  using parser_data_type = ParsedPassManager;
+  using OptVal = OptionValue<mlir::OpPassManager>;
+
+  parser(Option &opt) : basic_parser(opt) {}
+
+  bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value);
+
+  /// Print an instance of the underlying option value to the given stream.
+  static void print(raw_ostream &os, const mlir::OpPassManager &value);
+
+  // Override the basic_parser value name with a more descriptive one.
+  StringRef getValueName() const override { return "pass-manager"; }
+
+  void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
+                       const OptVal &defaultValue, size_t globalWidth) const;
+
+  // An out-of-line virtual method to provide a 'home' for this class.
+  void anchor() override;
+};
+
+} // namespace cl
+} // namespace llvm
+
+#endif // MLIR_PASS_PASSOPTIONS_H_

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index cda2997c5e621..4569ac27856f4 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -83,7 +83,7 @@ def Inliner : Pass<"inline"> {
   let options = [
     Option<"defaultPipelineStr", "default-pipeline", "std::string",
            /*default=*/"", "The default optimizer pipeline used for callables">,
-    ListOption<"opPipelineStrs", "op-pipelines", "std::string",
+    ListOption<"opPipelineList", "op-pipelines", "OpPassManager",
                "Callable operation specific optimizer pipelines (in the form "
                "of `dialect.op(pipeline)`)">,
     Option<"maxInliningIterations", "max-iterations", "unsigned",

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 7256f44b6adb4..8b4310272d871 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -54,12 +54,14 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
 void Pass::printAsTextualPipeline(raw_ostream &os) {
   // Special case for adaptors to use the 'op_name(sub_passes)' format.
   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
-    llvm::interleaveComma(adaptor->getPassManagers(), os,
-                          [&](OpPassManager &pm) {
-                            os << pm.getOpName() << "(";
-                            pm.printAsTextualPipeline(os);
-                            os << ")";
-                          });
+    llvm::interleave(
+        adaptor->getPassManagers(),
+        [&](OpPassManager &pm) {
+          os << pm.getOpName() << "(";
+          pm.printAsTextualPipeline(os);
+          os << ")";
+        },
+        [&] { os << ","; });
     return;
   }
   // Otherwise, print the pass argument followed by its options. If the pass
@@ -295,14 +297,17 @@ OperationName OpPassManager::getOpName(MLIRContext &context) const {
 /// Prints out the given passes as the textual representation of a pipeline.
 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
                                    raw_ostream &os) {
-  llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
-    pass->printAsTextualPipeline(os);
-  });
+  llvm::interleave(
+      passes,
+      [&](const std::unique_ptr<Pass> &pass) {
+        pass->printAsTextualPipeline(os);
+      },
+      [&] { os << ","; });
 }
 
 /// Prints out the passes of the pass manager as the textual representation
 /// of pipelines.
-void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
+void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
   ::printAsTextualPipeline(impl->passes, os);
 }
 

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 5088ae244c729..8a4f117451dc1 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -332,6 +332,104 @@ size_t detail::PassOptions::getOptionWidth() const {
   return max;
 }
 
+//===----------------------------------------------------------------------===//
+// MLIR Options
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: OptionValue
+
+llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
+llvm::cl::OptionValue<OpPassManager>::OptionValue(
+    const mlir::OpPassManager &value) {
+  setValue(value);
+}
+llvm::cl::OptionValue<OpPassManager> &
+llvm::cl::OptionValue<OpPassManager>::operator=(
+    const mlir::OpPassManager &rhs) {
+  setValue(rhs);
+  return *this;
+}
+
+llvm::cl::OptionValue<OpPassManager>::~OptionValue() = default;
+
+void llvm::cl::OptionValue<OpPassManager>::setValue(
+    const OpPassManager &newValue) {
+  if (hasValue())
+    *value = newValue;
+  else
+    value = std::make_unique<mlir::OpPassManager>(newValue);
+}
+void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
+  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
+  assert(succeeded(pipeline) && "invalid pass pipeline");
+  setValue(*pipeline);
+}
+
+bool llvm::cl::OptionValue<OpPassManager>::compare(
+    const mlir::OpPassManager &rhs) const {
+  // An unset option holds no pipeline, so it never matches `rhs`.
+  if (!hasValue())
+    return false;
+
+  // Use the textual format for pipeline comparisons.
+  std::string lhsStr, rhsStr;
+  raw_string_ostream lhsStream(lhsStr), rhsStream(rhsStr);
+  value->printAsTextualPipeline(lhsStream);
+  rhs.printAsTextualPipeline(rhsStream);
+  // str() flushes the streams into the backing strings before comparing.
+  return lhsStream.str() == rhsStream.str();
+}
+
+void llvm::cl::OptionValue<OpPassManager>::anchor() {}
+
+//===----------------------------------------------------------------------===//
+// OpPassManager: Parser
+
+namespace llvm {
+namespace cl {
+template class basic_parser<OpPassManager>;
+} // namespace cl
+} // namespace llvm
+
+bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
+                                            ParsedPassManager &value) {
+  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
+  if (failed(pipeline))
+    return true;
+  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
+  return false;
+}
+
+void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
+                                            const OpPassManager &value) {
+  value.printAsTextualPipeline(os);
+}
+
+void llvm::cl::parser<OpPassManager>::printOptionDiff(
+    const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
+    size_t globalWidth) const {
+  printOptionName(opt, globalWidth);
+  outs() << "= ";
+  pm.printAsTextualPipeline(outs());
+
+  if (defaultValue.hasValue()) {
+    outs().indent(2) << " (default: ";
+    defaultValue.getValue().printAsTextualPipeline(outs());
+    outs() << ")";
+  }
+  outs() << "\n";
+}
+
+void llvm::cl::parser<OpPassManager>::anchor() {}
+
+llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
+    default;
+llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
+    ParsedPassManager &&) = default;
+llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
+    default;
+
 //===----------------------------------------------------------------------===//
 // TextualPassPipeline Parser
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 145a113fd1a35..ae1e2aabd2381 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -585,14 +585,8 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
     return;
 
   // Update the option for the op specific optimization pipelines.
-  for (auto &it : opPipelines) {
-    std::string pipeline;
-    llvm::raw_string_ostream pipelineOS(pipeline);
-    pipelineOS << it.getKey() << "(";
-    it.second.printAsTextualPipeline(pipelineOS);
-    pipelineOS << ")";
-    opPipelineStrs.addValue(pipeline);
-  }
+  for (auto &it : opPipelines)
+    opPipelineList.addValue(it.second);
   this->opPipelines.emplace_back(std::move(opPipelines));
 }
 
@@ -751,15 +745,9 @@ LogicalResult InlinerPass::initializeOptions(StringRef options) {
 
   // Initialize the op specific pass pipelines.
   llvm::StringMap<OpPassManager> pipelines;
-  for (StringRef pipeline : opPipelineStrs) {
-    // Skip empty pipelines.
-    if (pipeline.empty())
-      continue;
-    FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
-    if (failed(pm))
-      return failure();
-    pipelines.try_emplace(pm->getOpName(), std::move(*pm));
-  }
+  for (const OpPassManager &pipeline : opPipelineList)
+    if (!pipeline.empty())
+      pipelines.try_emplace(pipeline.getOpName(), pipeline);
   opPipelines.assign({std::move(pipelines)});
 
   return success();

diff  --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h
index 7e1fedc136c9e..7c1f53929fe45 100644
--- a/mlir/lib/Transforms/PassDetail.h
+++ b/mlir/lib/Transforms/PassDetail.h
@@ -10,6 +10,7 @@
 #define TRANSFORMS_PASSDETAIL_H_
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir {

diff  --git a/mlir/test/Pass/crash-recovery.mlir b/mlir/test/Pass/crash-recovery.mlir
index 31db0fd92abbb..3e379dcaf6a2e 100644
--- a/mlir/test/Pass/crash-recovery.mlir
+++ b/mlir/test/Pass/crash-recovery.mlir
@@ -20,7 +20,7 @@ module @inner_mod1 {
   module @foo {}
 }
 
-// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass, test-pass-crash)'
+// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)'
 
 // REPRO: module @inner_mod1
 // REPRO: module @foo {

diff  --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index 987ab7ad29025..74eea1a210188 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -14,4 +14,4 @@
 
 // CHECK_1: test-options-pass{list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
 // CHECK_2: test-options-pass{list=1 string= string-list=a,b}
-// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }), func.func(test-options-pass{list=1,2,3,4 string= }))
+// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }),func.func(test-options-pass{list=1,2,3,4 string= }))

diff  --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 4d762edd58441..883fde94a1c61 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -2,6 +2,7 @@
 // RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
 // RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
 // RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
+// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY
 
 // Inline a function that takes an argument.
 func @func_with_arg(%c : i32) -> i32 {


        


More information about the Mlir-commits mailing list