[Mlir-commits] [mlir] 165c6d1 - [mlir] Add support for parsing nested PassPipelineOptions (#101118)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 9 13:54:04 PDT 2024


Author: Nikhil Kalra
Date: 2024-08-09T13:54:00-07:00
New Revision: 165c6d12519cf66f4ef3f5a00f9b1ed83613ff28

URL: https://github.com/llvm/llvm-project/commit/165c6d12519cf66f4ef3f5a00f9b1ed83613ff28
DIFF: https://github.com/llvm/llvm-project/commit/165c6d12519cf66f4ef3f5a00f9b1ed83613ff28.diff

LOG: [mlir] Add support for parsing nested PassPipelineOptions (#101118)

- Added a default parser for types derived from `PassOptions`, allowing
`Option`/`ListOption` to wrap `PassOptions` objects. This is helpful when
creating meta-pipelines (pass pipelines composed of pass pipelines).
- Updated `ListOption` printing to wrap list values in `{...}`, so that the
output of `dump-pass-pipeline` can be round-tripped back into `mlir-opt` for
more complex structures (e.g. lists of nested pass options).

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassOptions.h
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/Pass/pipeline-options-parsing.mlir
    mlir/test/lib/Pass/TestPassManager.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index 6bffa84f7b16b4..a5a3f1c1c19652 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -139,6 +139,25 @@ class PassOptions : protected llvm::cl::SubCommand {
     }
   };
 
+  /// Parser for pass options whose value type is itself derived from
+  /// PassOptions (e.g. a nested PassPipelineOptions struct). Like
+  /// GenericOptionParser, it is a thin wrapper around llvm::cl::basic_parser.
+  template <typename PassOptionsT>
+  struct PassOptionsParser : public llvm::cl::basic_parser<PassOptionsT> {
+    using llvm::cl::basic_parser<PassOptionsT>::basic_parser;
+
+    /// Print `value` via `PassOptionsT::print`, i.e. in the textual form that
+    /// `parseFromString` (and therefore `parse`) accepts.
+    static void print(llvm::raw_ostream &os, const PassOptionsT &value) {
+      value.print(os);
+    }
+
+    /// Parse `arg` into `value` via `PassOptionsT::parseFromString`. As with
+    /// every llvm::cl parser, the result is true on error and false on
+    /// success.
+    bool parse(llvm::cl::Option &, StringRef, StringRef arg,
+               PassOptionsT &value) {
+      const auto parseResult = value.parseFromString(arg);
+      return failed(parseResult);
+    }
+  };
+
   /// Utility methods for printing option values.
   template <typename DataT>
   static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
@@ -154,19 +173,24 @@ class PassOptions : protected llvm::cl::SubCommand {
   }
 
 public:
-  /// The specific parser to use depending on llvm::cl parser used. This is only
-  /// necessary because we need to provide additional methods for certain data
-  /// type parsers.
-  /// TODO: We should upstream the methods in GenericOptionParser to avoid the
-  /// need to do this.
+  /// The specific parser to use. This is necessary because we need to provide
+  /// additional methods for certain data type parsers.
   template <typename DataType>
-  using OptionParser =
+  using OptionParser = std::conditional_t<
+      // If the data type is derived from PassOptions, use the
+      // PassOptionsParser.
+      std::is_base_of_v<PassOptions, DataType>, PassOptionsParser<DataType>,
+      // Otherwise, use GenericOptionParser if llvm::cl::parser<DataType>
+      // derives from llvm::cl::generic_parser_base, else llvm::cl::parser.
+      // TODO: We should upstream the methods in GenericOptionParser to avoid
+      // the need to do this.
       std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
                                          llvm::cl::parser<DataType>>::value,
                          GenericOptionParser<DataType>,
-                         llvm::cl::parser<DataType>>;
+                         llvm::cl::parser<DataType>>>;
 
-  /// This class represents a specific pass option, with a provided data type.
+  /// This class represents a specific pass option, with a provided
+  /// data type.
   template <typename DataType, typename OptionParser = OptionParser<DataType>>
   class Option
       : public llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>,
@@ -278,11 +302,12 @@ class PassOptions : protected llvm::cl::SubCommand {
       if ((**this).empty())
         return;
 
-      os << this->ArgStr << '=';
+      os << this->ArgStr << "={";
       auto printElementFn = [&](const DataType &value) {
         printValue(os, this->getParser(), value);
       };
       llvm::interleave(*this, os, printElementFn, ",");
+      os << "}";
     }
 
     /// Copy the value from the given option into this one.
@@ -311,7 +336,7 @@ class PassOptions : protected llvm::cl::SubCommand {
 
   /// Print the options held by this struct in a form that can be parsed via
   /// 'parseFromString'.
-  void print(raw_ostream &os);
+  void print(raw_ostream &os) const;
 
   /// Print the help string for the options held by this struct. `descIndent` is
   /// the indent that the descriptions should be aligned.

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index cb7c15580ec54d..fe842755958418 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/MemoryBuffer.h"
@@ -185,6 +186,31 @@ const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
 // PassOptions
 //===----------------------------------------------------------------------===//
 
+/// Return the index of the '}' that closes the '{' at the front of `str`,
+/// skipping over nested '{...}' scopes and terminated '...' / "..." quotes.
+/// Returns StringRef::npos if `str` does not start with '{' or the scope is
+/// never closed.
+static size_t findClosingBrace(StringRef str) {
+  if (str.empty() || str.front() != '{')
+    return StringRef::npos;
+
+  size_t depth = 0;
+  for (size_t i = 0, e = str.size(); i < e; ++i) {
+    char c = str[i];
+    if (c == '\'' || c == '"') {
+      // Jump over a quoted region; an unterminated quote is treated as an
+      // ordinary character.
+      size_t closeQuote = str.find(c, i + 1);
+      if (closeQuote != StringRef::npos)
+        i = closeQuote;
+    } else if (c == '{') {
+      ++depth;
+    } else if (c == '}' && --depth == 0) {
+      return i;
+    }
+  }
+  return StringRef::npos;
+}
+
+/// Extract an argument of `argSize` characters from 'options' and update it
+/// to point after the arg (with leading whitespace dropped). The returned
+/// argument is trimmed and at most one enclosing escape pair is removed:
+///   * '...' and "..." are always unwrapped (they denote literals);
+///   * {...} is unwrapped only when its leading '{' is closed by the final
+///     '}', so that values such as `{a},{b}` (e.g. a list of nested option
+///     structs) keep their braces.
+/// A bare pair (`''`, `""`, `{}`) yields an empty string.
+static StringRef extractArgAndUpdateOptions(StringRef &options,
+                                            size_t argSize) {
+  StringRef str = options.take_front(argSize).trim();
+  options = options.drop_front(argSize).ltrim();
+
+  // A single character cannot be wrapped in an escape pair.
+  if (str.size() < 2)
+    return str;
+
+  // Quoted literals: strip the quotes without looking inside.
+  for (char quote : {'\'', '"'})
+    if (str.front() == quote && str.back() == quote)
+      return str.drop_front().drop_back().trim();
+
+  // Braced scopes: strip only a single scope that spans the whole argument.
+  if (findClosingBrace(str) == str.size() - 1)
+    return str.drop_front().drop_back().trim();
+
+  return str;
+}
+
 LogicalResult detail::pass_options::parseCommaSeparatedList(
     llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
     function_ref<LogicalResult(StringRef)> elementParseFn) {
@@ -213,13 +239,16 @@ LogicalResult detail::pass_options::parseCommaSeparatedList(
   size_t nextElePos = findChar(optionStr, 0, ',');
   while (nextElePos != StringRef::npos) {
     // Process the portion before the comma.
-    if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
+    if (failed(
+            elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
       return failure();
 
-    optionStr = optionStr.substr(nextElePos + 1);
+    // Drop the leading ','
+    optionStr = optionStr.drop_front();
     nextElePos = findChar(optionStr, 0, ',');
   }
-  return elementParseFn(optionStr.substr(0, nextElePos));
+  return elementParseFn(
+      extractArgAndUpdateOptions(optionStr, optionStr.size()));
 }
 
 /// Out of line virtual function to provide home for the class.
@@ -239,27 +268,6 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
 /// `options` string pointing after the parsed option].
 static std::tuple<StringRef, StringRef, StringRef>
 parseNextArg(StringRef options) {
-  // Functor used to extract an argument from 'options' and update it to point
-  // after the arg.
-  auto extractArgAndUpdateOptions = [&](size_t argSize) {
-    StringRef str = options.take_front(argSize).trim();
-    options = options.drop_front(argSize).ltrim();
-    // Handle escape sequences
-    if (str.size() > 2) {
-      const auto escapePairs = {std::make_pair('\'', '\''),
-                                std::make_pair('"', '"'),
-                                std::make_pair('{', '}')};
-      for (const auto &escape : escapePairs) {
-        if (str.front() == escape.first && str.back() == escape.second) {
-          // Drop the escape characters and trim.
-          str = str.drop_front().drop_back().trim();
-          // Don't process additional escape sequences.
-          break;
-        }
-      }
-    }
-    return str;
-  };
   // Try to process the given punctuation, properly escaping any contained
   // characters.
   auto tryProcessPunct = [&](size_t &currentPos, char punct) {
@@ -276,13 +284,13 @@ parseNextArg(StringRef options) {
   for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
     // Check for the end of the full option.
     if (argEndIt == optionsE || options[argEndIt] == ' ') {
-      argName = extractArgAndUpdateOptions(argEndIt);
+      argName = extractArgAndUpdateOptions(options, argEndIt);
       return std::make_tuple(argName, StringRef(), options);
     }
 
     // Check for the end of the name and the start of the value.
     if (options[argEndIt] == '=') {
-      argName = extractArgAndUpdateOptions(argEndIt);
+      argName = extractArgAndUpdateOptions(options, argEndIt);
       options = options.drop_front();
       break;
     }
@@ -292,7 +300,7 @@ parseNextArg(StringRef options) {
   for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
     // Handle the end of the options string.
     if (argEndIt == optionsE || options[argEndIt] == ' ') {
-      StringRef value = extractArgAndUpdateOptions(argEndIt);
+      StringRef value = extractArgAndUpdateOptions(options, argEndIt);
       return std::make_tuple(argName, value, options);
     }
 
@@ -344,7 +352,7 @@ LogicalResult detail::PassOptions::parseFromString(StringRef options,
 
 /// Print the options held by this struct in a form that can be parsed via
 /// 'parseFromString'.
-void detail::PassOptions::print(raw_ostream &os) {
+void detail::PassOptions::print(raw_ostream &os) const {
   // If there are no options, there is nothing left to do.
   if (OptionsMap.empty())
     return;

diff  --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index 50f08241ee5cfa..b6c2b688b7cfb3 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -11,6 +11,8 @@
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string="foo bar baz"})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz}})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_5 %s
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s
+// RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
+// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
 
 // CHECK_ERROR_1: missing closing '}' while processing pass options
 // CHECK_ERROR_2: no such option test-option
@@ -18,9 +20,10 @@
 // CHECK_ERROR_4: 'notaninteger' value invalid for integer argument
 // CHECK_ERROR_5: for the --enum option: Cannot find option named 'invalid'!
 
-// CHECK_1: test-options-pass{enum=zero list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
-// CHECK_2: test-options-pass{enum=one list=1 string= string-list=a,b}
-// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string= })))
-// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foobar })))
-// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string={foo bar baz} })))
-// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list=3 string= }),func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz })))
+// CHECK_1: test-options-pass{enum=zero list={1,2,3,4,5} string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list={a,b,c,d}}
+// CHECK_2: test-options-pass{enum=one list={1} string= string-list={a,b}}
+// CHECK_3: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string= })))
+// CHECK_4: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foobar })))
+// CHECK_5: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string={foo bar baz} })))
+// CHECK_6: builtin.module(builtin.module(func.func(test-options-pass{enum=zero list={3} string= }),func.func(test-options-pass{enum=one list={1,2,3,4} string=foo"bar"baz })))
+// CHECK_7{LITERAL}: builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))

diff  --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 2762e254903245..ee32bec0c79bd4 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -54,7 +54,7 @@ struct TestOptionsPass
     : public PassWrapper<TestOptionsPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass)
 
-  enum Enum { One, Two };
+  /// Enumerator values match the integers passed to clEnumValN for the "enum"
+  /// options (zero = 0, one = 1, two = 2).
+  enum Enum { Zero, One, Two };
 
   struct Options : public PassPipelineOptions<Options> {
     ListOption<int> listOption{*this, "list",
@@ -66,7 +66,15 @@ struct TestOptionsPass
     Option<Enum> enumOption{
         *this, "enum", llvm::cl::desc("Example enum option"),
         llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
-                         clEnumValN(1, "one", "Example one value"))};
+                         clEnumValN(1, "one", "Example one value"),
+                         clEnumValN(2, "two", "Example two value"))};
+
+    Options() = default;
+    // Copying transfers option *values* via copyOptionValuesFrom; the option
+    // members of the new object are constructed against its own `*this`.
+    Options(const Options &rhs) { copyOptionValuesFrom(rhs); }
+    Options &operator=(const Options &rhs) {
+      copyOptionValuesFrom(rhs);
+      return *this;
+    }
   };
   TestOptionsPass() = default;
   TestOptionsPass(const TestOptionsPass &) : PassWrapper() {}
@@ -92,7 +100,37 @@ struct TestOptionsPass
   Option<Enum> enumOption{
       *this, "enum", llvm::cl::desc("Example enum option"),
       llvm::cl::values(clEnumValN(0, "zero", "Example zero value"),
-                       clEnumValN(1, "one", "Example one value"))};
+                       clEnumValN(1, "one", "Example one value"),
+                       clEnumValN(2, "two", "Example two value"))};
+};
+
+/// Test pass whose option is a list of TestOptionsPass::Options, used to
+/// exercise parsing and printing of nested PassPipelineOptions.
+struct TestOptionsSuperPass
+    : public PassWrapper<TestOptionsSuperPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass)
+
+  /// Pipeline-registration options: a "super-list" of nested pass options.
+  struct Options : public PassPipelineOptions<Options> {
+    Options() = default;
+
+    ListOption<TestOptionsPass::Options> listOption{
+        *this, "super-list",
+        llvm::cl::desc("Example list of PassPipelineOptions option")};
+  };
+
+  TestOptionsSuperPass() = default;
+  TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {}
+  /// Seed the pass-level "list" option from the pipeline's "super-list".
+  TestOptionsSuperPass(const Options &pipelineOptions) {
+    listOption = pipelineOptions.listOption;
+  }
+
+  StringRef getArgument() const final { return "test-options-super-pass"; }
+  StringRef getDescription() const final {
+    return "Test options of options parsing capabilities";
+  }
+  /// No-op: this pass only exists to test option parsing.
+  void runOnOperation() final {}
+
+  ListOption<TestOptionsPass::Options> listOption{
+      *this, "list",
+      llvm::cl::desc("Example list of PassPipelineOptions option")};
 };
 
 /// A test pass that always aborts to enable testing the crash recovery
@@ -220,6 +258,7 @@ static void testNestedPipelineTextual(OpPassManager &pm) {
 namespace mlir {
 void registerPassManagerTestPass() {
   PassRegistration<TestOptionsPass>();
+  PassRegistration<TestOptionsSuperPass>();
 
   PassRegistration<TestModulePass>();
 
@@ -248,5 +287,14 @@ void registerPassManagerTestPass() {
           [](OpPassManager &pm, const TestOptionsPass::Options &options) {
             pm.addPass(std::make_unique<TestOptionsPass>(options));
           });
+
+  PassPipelineRegistration<TestOptionsSuperPass::Options>
+      registerOptionsSuperPassPipeline(
+          "test-options-super-pass-pipeline",
+          "Parses options of PassPipelineOptions using pass pipeline "
+          "registration",
+          [](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
+            pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
+          });
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list