[Mlir-commits] [mlir] 6edef13 - [mlir:PassOption] Rework ListOption parsing and add support for std::vector/SmallVector options

River Riddle llvmlistbot at llvm.org
Sat Apr 2 01:07:47 PDT 2022


Author: River Riddle
Date: 2022-04-02T00:45:11-07:00
New Revision: 6edef1356921d9cad1a8cd6169207450741536a6

URL: https://github.com/llvm/llvm-project/commit/6edef1356921d9cad1a8cd6169207450741536a6
DIFF: https://github.com/llvm/llvm-project/commit/6edef1356921d9cad1a8cd6169207450741536a6.diff

LOG: [mlir:PassOption] Rework ListOption parsing and add support for std::vector/SmallVector options

ListOption currently uses llvm::cl::list under the hood, but ListOption is
generally used somewhat differently from llvm::cl::list. This commit codifies
this by making ListOption implicitly comma-separated, and removes the
now-redundant explicit CommaSeparated flag from all current list options.
The new comma-separation parsing for ListOption also supports skipping over
delimited sub-ranges (i.e. {}, [], (), "", ''). This makes it easier to
support nested options that use these delimiters as part of their format,
and this constraint (balanced delimiters) is already codified in the syntax
of pass pipelines.

See https://discourse.llvm.org/t/list-of-lists-pass-option/5950 for
related discussion.

Differential Revision: https://reviews.llvm.org/D122879

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/docs/PatternRewriter.md
    mlir/include/mlir/Dialect/Affine/Passes.td
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/SCF/Passes.td
    mlir/include/mlir/Pass/PassOptions.h
    mlir/include/mlir/Reducer/Passes.td
    mlir/include/mlir/Rewrite/PassUtil.td
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/Dialect/Linalg/hoist-padding.mlir
    mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
    mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp
    mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    mlir/test/lib/IR/TestDiagnostics.cpp
    mlir/test/lib/Pass/TestDynamicPipeline.cpp
    mlir/test/lib/Pass/TestPassManager.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 61418d438f2f8..aec01f6231813 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -431,9 +431,12 @@ components are integrated with the dynamic pipeline being executed.
 MLIR provides a builtin mechanism for passes to specify options that configure
 its behavior. These options are parsed at pass construction time independently
 for each instance of the pass. Options are defined using the `Option<>` and
-`ListOption<>` classes, and follow the
+`ListOption<>` classes, and generally follow the
 [LLVM command line](https://llvm.org/docs/CommandLine.html) flag definition
-rules. See below for a few examples:
+rules. One major distinction from the LLVM command line functionality is that
+all `ListOption`s are comma-separated, and commas within delimited sub-ranges
+of an element (e.g. `{}`, `[]`, `()`, `""`) do not split the top-level list.
+See below for a few examples:
 
 ```c++
 struct MyPass ... {
@@ -445,8 +448,7 @@ struct MyPass ... {
   /// Any parameters after the description are forwarded to llvm::cl::list and
   /// llvm::cl::opt respectively.
   Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
-  ListOption<int> exampleListOption{*this, "list-flag-name",
-                                    llvm::cl::desc("...")};
+  ListOption<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
 };
 ```
 
@@ -705,8 +707,7 @@ struct MyPass : PassWrapper<MyPass, OperationPass<ModuleOp>> {
       llvm::cl::desc("An example option"), llvm::cl::init(true)};
   ListOption<int64_t> listOption{
       *this, "example-list",
-      llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore,
-      llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore};
 
   // Specify any statistics.
   Statistic statistic{this, "example-statistic", "An example statistic"};
@@ -742,8 +743,7 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
     Option<"option", "example-option", "bool", /*default=*/"true",
            "An example option">,
     ListOption<"listOption", "example-list", "int64_t",
-               "An example list option",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+               "An example list option", "llvm::cl::ZeroOrMore">
   ];
 
   // Specify any statistics.
@@ -879,8 +879,7 @@ The `ListOption` class takes the following fields:
 def MyPass : Pass<"my-pass"> {
   let options = [
     ListOption<"listOption", "example-list", "int64_t",
-               "An example list option",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+               "An example list option", "llvm::cl::ZeroOrMore">
   ];
 }
 ```

diff  --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 7b0db967c35b7..1eb594da9e291 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -439,12 +439,10 @@ below:
 
 ```tablegen
 ListOption<"disabledPatterns", "disable-patterns", "std::string",
-           "Labels of patterns that should be filtered out during application",
-           "llvm::cl::MiscFlags::CommaSeparated">,
+           "Labels of patterns that should be filtered out during application">,
 ListOption<"enabledPatterns", "enable-patterns", "std::string",
            "Labels of patterns that should be used during application, all "
-           "other patterns are filtered out",
-           "llvm::cl::MiscFlags::CommaSeparated">,
+           "other patterns are filtered out">,
 ```
 
 These options may be used to provide filtering behavior when constructing any

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 50783adb8b4de..e91991213eddb 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -348,7 +348,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
   let options = [
     ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
                "Specify an n-D virtual vector size for vectorization",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+               "llvm::cl::ZeroOrMore">,
     // Optionally, the fixed mapping from loop to fastest varying MemRef
     // dimension for all the MemRefs within a loop pattern:
     //   the index represents the loop depth, the value represents the k^th
@@ -359,7 +359,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
                "Specify a 1-D, 2-D or 3-D pattern of fastest varying memory "
                "dimensions to match. See defaultPatterns in Vectorize.cpp for "
                "a description and examples. This is used for testing purposes",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+               "llvm::cl::ZeroOrMore">,
     Option<"vectorizeReductions", "vectorize-reductions", "bool",
            /*default=*/"false",
            "Vectorize known reductions expressed via iter_args. "

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index ee7579b1a3edb..6d56f47b8a375 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -215,8 +215,7 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            "Specify if buffers should be deallocated. For compatibility with "
            "core bufferization passes.">,
     ListOption<"dialectFilter", "dialect-filter", "std::string",
-               "Restrict bufferization to ops from these dialects.",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               "Restrict bufferization to ops from these dialects.">,
     Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
            /*default=*/"true",
            "Generate MemRef types with dynamic offset+strides by default.">,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b84b99722fa7a..05184ce3635e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -194,7 +194,7 @@ def LinalgTiling : Pass<"linalg-tile", "FuncOp"> {
   ];
   let options = [
     ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+               "llvm::cl::ZeroOrMore">,
     Option<"loopType", "loop-type", "std::string", /*default=*/"\"for\"",
            "Specify the type of loops to generate: for, parallel">
   ];

diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index b182be54eef80..72f7ff75cc7b1 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -55,14 +55,11 @@ def SCFParallelLoopCollapsing : Pass<"scf-parallel-loop-collapsing"> {
   let constructor = "mlir::createParallelLoopCollapsingPass()";
   let options = [
     ListOption<"clCollapsedIndices0", "collapsed-indices-0", "unsigned",
-               "Which loop indices to combine 0th loop index",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               "Which loop indices to combine 0th loop index">,
     ListOption<"clCollapsedIndices1", "collapsed-indices-1", "unsigned",
-               "Which loop indices to combine into the position 1 loop index",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               "Which loop indices to combine into the position 1 loop index">,
     ListOption<"clCollapsedIndices2", "collapsed-indices-2", "unsigned",
-               "Which loop indices to combine into the position 2 loop index",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               "Which loop indices to combine into the position 2 loop index">,
   ];
 }
 
@@ -77,8 +74,7 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling", "FuncOp"> {
   let constructor = "mlir::createParallelLoopTilingPass()";
   let options = [
     ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
-               "Factors to tile parallel loops by",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+               "Factors to tile parallel loops by", "llvm::cl::ZeroOrMore">,
     Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
            /*default=*/"false",
            "Perform tiling with fixed upper bound with inbound check "

diff  --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index 13c44345bf23d..53e0060333dce 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
@@ -23,6 +24,55 @@
 
 namespace mlir {
 namespace detail {
+namespace pass_options {
+/// Parse a string containing a list of comma-delimited elements, invoking the
+/// given parser for each sub-element and passing them to the provided
+/// element-append functor.
+LogicalResult
+parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
+                        StringRef optionStr,
+                        function_ref<LogicalResult(StringRef)> elementParseFn);
+template <typename ElementParser, typename ElementAppendFn>
+LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
+                                      StringRef optionStr,
+                                      ElementParser &elementParser,
+                                      ElementAppendFn &&appendFn) {
+  return parseCommaSeparatedList(
+      opt, argName, optionStr, [&](StringRef valueStr) {
+        typename ElementParser::parser_data_type value = {};
+        if (elementParser.parse(opt, argName, valueStr, value))
+          return failure();
+        appendFn(value);
+        return success();
+      });
+}
+
+/// Trait used to detect if a type can be printed via `raw_ostream << T`.
+template <typename T>
+using has_stream_operator_trait =
+    decltype(std::declval<raw_ostream &>() << std::declval<T>());
+template <typename T>
+using has_stream_operator = llvm::is_detected<has_stream_operator_trait, T>;
+
+/// Utility methods for printing option values.
+template <typename ParserT>
+static void printOptionValue(raw_ostream &os, const bool &value) {
+  os << (value ? StringRef("true") : StringRef("false"));
+}
+template <typename ParserT, typename DataT>
+static std::enable_if_t<has_stream_operator<DataT>::value>
+printOptionValue(raw_ostream &os, const DataT &value) {
+  os << value;
+}
+template <typename ParserT, typename DataT>
+static std::enable_if_t<!has_stream_operator<DataT>::value>
+printOptionValue(raw_ostream &os, const DataT &value) {
+  // If the value can't be streamed, fallback to checking for a print in the
+  // parser.
+  ParserT::print(os, value);
+}
+} // namespace pass_options
+
 /// Base container class and manager for all pass options.
 class PassOptions : protected llvm::cl::SubCommand {
 private:
@@ -85,11 +135,7 @@ class PassOptions : protected llvm::cl::SubCommand {
   }
   template <typename DataT, typename ParserT>
   static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
-    os << value;
-  }
-  template <typename ParserT>
-  static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
-    os << (value ? StringRef("true") : StringRef("false"));
+    detail::pass_options::printOptionValue<ParserT>(os, value);
   }
 
 public:
@@ -149,22 +195,27 @@ class PassOptions : protected llvm::cl::SubCommand {
   };
 
   /// This class represents a specific pass option that contains a list of
-  /// values of the provided data type.
+  /// values of the provided data type. The elements within the textual form of
+  /// this option are parsed assuming they are comma-separated. Delimited
+  /// sub-ranges within individual elements of the list may contain commas that
+  /// are not treated as separators for the top-level list.
   template <typename DataType, typename OptionParser = OptionParser<DataType>>
   class ListOption
       : public llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>,
         public OptionBase {
   public:
     template <typename... Args>
-    ListOption(PassOptions &parent, StringRef arg, Args &&... args)
+    ListOption(PassOptions &parent, StringRef arg, Args &&...args)
         : llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
-              arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
+              arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
+          elementParser(*this) {
       assert(!this->isPositional() && !this->isSink() &&
              "sink and positional options are not supported");
+      assert(!(this->getMiscFlags() & llvm::cl::MiscFlags::CommaSeparated) &&
+             "ListOption is implicitly comma separated, specifying "
+             "CommaSeparated is extraneous");
       parent.options.push_back(this);
-
-      // Set a callback to track if this option has a value.
-      this->setCallback([this](const auto &) { this->optHasValue = true; });
+      elementParser.initialize();
     }
     ~ListOption() override = default;
     ListOption<DataType, OptionParser> &
@@ -174,6 +225,14 @@ class PassOptions : protected llvm::cl::SubCommand {
       return *this;
     }
 
+    bool handleOccurrence(unsigned pos, StringRef argName,
+                          StringRef arg) override {
+      this->optHasValue = true;
+      return failed(detail::pass_options::parseCommaSeparatedList(
+          *this, argName, arg, elementParser,
+          [&](const DataType &value) { this->addValue(value); }));
+    }
+
     /// Allow assigning from an ArrayRef.
     ListOption<DataType, OptionParser> &operator=(ArrayRef<DataType> values) {
       ((std::vector<DataType> &)*this).assign(values.begin(), values.end());
@@ -211,6 +270,9 @@ class PassOptions : protected llvm::cl::SubCommand {
     void copyValueFrom(const OptionBase &other) final {
       *this = static_cast<const ListOption<DataType, OptionParser> &>(other);
     }
+
+    /// The parser to use for parsing the list elements.
+    OptionParser elementParser;
   };
 
   PassOptions() = default;
@@ -255,9 +317,7 @@ class PassOptions : protected llvm::cl::SubCommand {
 /// Usage:
 ///
 /// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
-///   ListOption<int> someListFlag{
-///        *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
-///        llvm::cl::desc("...")};
+///   ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
 /// };
 template <typename T> class PassPipelineOptions : public detail::PassOptions {
 public:
@@ -278,5 +338,77 @@ struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
 
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// MLIR Options
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+namespace cl {
+//===----------------------------------------------------------------------===//
+// std::vector+SmallVector
+
+namespace detail {
+template <typename VectorT, typename ElementT>
+class VectorParserBase : public basic_parser_impl {
+public:
+  VectorParserBase(Option &opt) : basic_parser_impl(opt), elementParser(opt) {}
+
+  using parser_data_type = VectorT;
+
+  bool parse(Option &opt, StringRef argName, StringRef arg,
+             parser_data_type &vector) {
+    if (!arg.consume_front("[") || !arg.consume_back("]")) {
+      return opt.error("expected vector option to be wrapped with '[]'",
+                       argName);
+    }
+
+    return failed(mlir::detail::pass_options::parseCommaSeparatedList(
+        opt, argName, arg, elementParser,
+        [&](const ElementT &value) { vector.push_back(value); }));
+  }
+
+  static void print(raw_ostream &os, const VectorT &vector) {
+    llvm::interleave(
+        vector, os,
+        [&](const ElementT &value) {
+          mlir::detail::pass_options::printOptionValue<
+              llvm::cl::parser<ElementT>>(os, value);
+        },
+        ",");
+  }
+
+  void printOptionInfo(const Option &opt, size_t globalWidth) const {
+    // Add the `vector<>` qualifier to the option info.
+    outs() << "  --" << opt.ArgStr;
+    outs() << "=<vector<" << elementParser.getValueName() << ">>";
+    Option::printHelpStr(opt.HelpStr, globalWidth, getOptionWidth(opt));
+  }
+
+  size_t getOptionWidth(const Option &opt) const {
+    // Add the `vector<>` qualifier to the option width.
+    StringRef vectorExt("vector<>");
+    return elementParser.getOptionWidth(opt) + vectorExt.size();
+  }
+
+private:
+  llvm::cl::parser<ElementT> elementParser;
+};
+} // namespace detail
+
+template <typename T>
+class parser<std::vector<T>>
+    : public detail::VectorParserBase<std::vector<T>, T> {
+public:
+  parser(Option &opt) : detail::VectorParserBase<std::vector<T>, T>(opt) {}
+};
+template <typename T, unsigned N>
+class parser<SmallVector<T, N>>
+    : public detail::VectorParserBase<SmallVector<T, N>, T> {
+public:
+  parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
+};
+} // end namespace cl
+} // end namespace llvm
+
 #endif // MLIR_PASS_PASSOPTIONS_H_
 

diff  --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 7fc4ba1643d74..acaf90f38e0fe 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -20,8 +20,7 @@ def CommonReductionPassOptions {
     Option<"testerName", "test", "std::string", /* default */"",
            "The location of the tester which tests the file interestingness">,
     ListOption<"testerArgs", "test-arg", "std::string",
-               "arguments of the tester",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+               "arguments of the tester", "llvm::cl::ZeroOrMore">,
   ];
 }
 

diff  --git a/mlir/include/mlir/Rewrite/PassUtil.td b/mlir/include/mlir/Rewrite/PassUtil.td
index 8fc548b11336a..10a039590bfa3 100644
--- a/mlir/include/mlir/Rewrite/PassUtil.td
+++ b/mlir/include/mlir/Rewrite/PassUtil.td
@@ -24,12 +24,10 @@ def RewritePassUtils {
     // created.
     ListOption<"disabledPatterns", "disable-patterns", "std::string",
                "Labels of patterns that should be filtered out during"
-               " application",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               " application">,
     ListOption<"enabledPatterns", "enable-patterns", "std::string",
                "Labels of patterns that should be used during"
-               " application, all other patterns are filtered out",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               " application, all other patterns are filtered out">,
   ];
 }
 

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 54f4094c7b4e5..cda2997c5e621 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -85,8 +85,7 @@ def Inliner : Pass<"inline"> {
            /*default=*/"", "The default optimizer pipeline used for callables">,
     ListOption<"opPipelineStrs", "op-pipelines", "std::string",
                "Callable operation specific optimizer pipelines (in the form "
-               "of `dialect.op(pipeline)`)",
-               "llvm::cl::MiscFlags::CommaSeparated">,
+               "of `dialect.op(pipeline)`)">,
     Option<"maxInliningIterations", "max-iterations", "unsigned",
            /*default=*/"4",
            "Maximum number of iterations when inlining within an SCC">,
@@ -226,8 +225,7 @@ def SymbolPrivatize : Pass<"symbol-privatize"> {
   }];
   let options = [
     ListOption<"exclude", "exclude", "std::string",
-       "Comma separated list of symbols that should not be marked private",
-       "llvm::cl::MiscFlags::CommaSeparated">
+       "Comma separated list of symbols that should not be marked private">
   ];
   let constructor = "mlir::createSymbolPrivatizePass()";
 }

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 6f3c861aa59ab..5088ae244c729 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -142,6 +142,43 @@ const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
 // PassOptions
 //===----------------------------------------------------------------------===//
 
+/// Returns the index of the first `c` in `str` at or after `index`, or npos.
+/// Balanced {}, (), [], "", '' ranges are skipped over as opaque units.
+static size_t findChar(StringRef str, size_t index, char c) {
+  for (size_t i = index, e = str.size(); i < e; ++i) {
+    if (str[i] == c)
+      return i;
+    if (str[i] == '{')
+      i = findChar(str, i + 1, '}');
+    else if (str[i] == '(')
+      i = findChar(str, i + 1, ')');
+    else if (str[i] == '[')
+      i = findChar(str, i + 1, ']');
+    else if (str[i] == '\"')
+      i = str.find_first_of('\"', i + 1);
+    else if (str[i] == '\'')
+      i = str.find_first_of('\'', i + 1);
+    // Unterminated range: bail out, as `++i` would wrap npos back to 0.
+    if (i == StringRef::npos)
+      return StringRef::npos;
+  }
+  return StringRef::npos;
+}
+
+LogicalResult detail::pass_options::parseCommaSeparatedList(
+    llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
+    function_ref<LogicalResult(StringRef)> elementParseFn) {
+  size_t nextElePos = findChar(optionStr, 0, ',');
+  while (nextElePos != StringRef::npos) {
+    // Process the portion before the comma, then continue after it.
+    if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
+      return failure();
+    optionStr = optionStr.substr(nextElePos + 1);
+    nextElePos = findChar(optionStr, 0, ',');
+  }
+  return elementParseFn(optionStr);
+}
+
 /// Out of line virtual function to provide home for the class.
 void detail::PassOptions::OptionBase::anchor() {}
 

diff  --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
index 15e6e511e970e..8dc5297bfec2e 100644
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATVEC
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 transpose-paddings=1:0,0,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=TRANSP
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 transpose-paddings=[1,0],[0],[0] run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=TRANSP
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad hoist-paddings=1,2,1 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATMUL
 
 //  MATVEC-DAG: #[[DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>

diff  --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
index da3ca784f7e84..0c65fc843578c 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
@@ -38,7 +38,7 @@ struct TestLoopPermutation
   /// transformed nest (with i going from outermost to innermost).
   ListOption<unsigned> permList{*this, "permutation-map",
                                 llvm::cl::desc("Specify the loop permutation"),
-                                llvm::cl::OneOrMore, llvm::cl::CommaSeparated};
+                                llvm::cl::OneOrMore};
 };
 
 } // namespace

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 09e93d60116ef..fa6a478a05f53 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -67,10 +67,9 @@ struct TestLinalgCodegenStrategy
       llvm::cl::desc("Fuse the producers after tiling the root op."),
       llvm::cl::init(false)};
   ListOption<int64_t> tileSizes{*this, "tile-sizes",
-                                llvm::cl::MiscFlags::CommaSeparated,
                                 llvm::cl::desc("Specifies the tile sizes.")};
   ListOption<int64_t> tileInterchange{
-      *this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "tile-interchange",
       llvm::cl::desc("Specifies the tile interchange.")};
 
   Option<bool> promote{
@@ -82,7 +81,7 @@ struct TestLinalgCodegenStrategy
       llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
       llvm::cl::init(false)};
   ListOption<int64_t> registerTileSizes{
-      *this, "register-tile-sizes", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "register-tile-sizes",
       llvm::cl::desc(
           "Specifies the size of the register tile that will be used "
           " to vectorize")};
@@ -100,33 +99,33 @@ struct TestLinalgCodegenStrategy
   ListOption<std::string> paddingValues{
       *this, "padding-values",
       llvm::cl::desc("Operand padding values parsed by the attribute parser."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::ZeroOrMore};
   ListOption<int64_t> paddingDimensions{
       *this, "padding-dimensions",
       llvm::cl::desc("Operation iterator dimensions to pad."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<int64_t> packPaddings{
-      *this, "pack-paddings", llvm::cl::desc("Operand packing flags."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<int64_t> hoistPaddings{
-      *this, "hoist-paddings", llvm::cl::desc("Operand hoisting depths."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<std::string> transposePaddings{
+      llvm::cl::ZeroOrMore};
+  ListOption<int64_t> packPaddings{*this, "pack-paddings",
+                                   llvm::cl::desc("Operand packing flags."),
+                                   llvm::cl::ZeroOrMore};
+  ListOption<int64_t> hoistPaddings{*this, "hoist-paddings",
+                                    llvm::cl::desc("Operand hoisting depths."),
+                                    llvm::cl::ZeroOrMore};
+  ListOption<SmallVector<int64_t>> transposePaddings{
       *this, "transpose-paddings",
       llvm::cl::desc(
           "Transpose paddings. Specify a operand dimension interchange "
           "using the following format:\n"
-          "-transpose-paddings=1:0:2,0:1,0:1\n"
+          "-transpose-paddings=[1,0,2],[0,1],[0,1]\n"
           "It defines the interchange [1, 0, 2] for operand one and "
           "the interchange [0, 1] (no transpose) for the remaining operands."
           "All interchange vectors have to be permuations matching the "
           "operand rank."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::ZeroOrMore};
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
                           llvm::cl::init(false)};
   ListOption<int64_t> iteratorInterchange{
-      *this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "iterator-interchange",
       llvm::cl::desc("Specifies the iterator interchange.")};
   Option<bool> decompose{
       *this, "decompose",
@@ -259,16 +258,6 @@ void TestLinalgCodegenStrategy::runOnOperation() {
   }
 
   // Parse the transpose vectors.
-  SmallVector<SmallVector<int64_t>> transposePaddingVectors;
-  for (const std::string &transposePadding : transposePaddings) {
-    SmallVector<int64_t> transposeVector = {};
-    SmallVector<StringRef> tokens;
-    StringRef(transposePadding).split(tokens, ':');
-    for (StringRef token : tokens)
-      transposeVector.push_back(std::stoi(token.str()));
-    transposePaddingVectors.push_back(transposeVector);
-  }
-
   LinalgPaddingOptions paddingOptions;
   paddingOptions.setPaddingValues(paddingValueAttributes);
   paddingOptions.setPaddingDimensions(
@@ -277,7 +266,7 @@ void TestLinalgCodegenStrategy::runOnOperation() {
       SmallVector<bool>{packPaddings.begin(), packPaddings.end()});
   paddingOptions.setHoistPaddings(
       SmallVector<int64_t>{hoistPaddings.begin(), hoistPaddings.end()});
-  paddingOptions.setTransposePaddings(transposePaddingVectors);
+  paddingOptions.setTransposePaddings(transposePaddings);
 
   vector::VectorContractLowering vectorContractLowering =
       llvm::StringSwitch<vector::VectorContractLowering>(

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 0618eb2f19d82..271837710ba4e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -270,9 +270,9 @@ struct TestLinalgTileAndFuseSequencePass
       const TestLinalgTileAndFuseSequencePass &pass)
       : PassWrapper(pass){};
 
-  ListOption<int64_t> tileSizes{
-      *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> tileSizes{*this, "tile-sizes",
+                                llvm::cl::desc("Tile sizes to use for ops"),
+                                llvm::cl::ZeroOrMore};
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 06d1d81a38592..23df4439cab0a 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -118,11 +118,11 @@ struct TestLinalgTransforms
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::ZeroOrMore};
   ListOption<int64_t> tileSizes{
       *this, "tile-sizes",
       llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::ZeroOrMore};
   Option<bool> skipPartial{
       *this, "skip-partial",
       llvm::cl::desc("Skip loops inside partial iterations during peeling"),

diff  --git a/mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp
index 71854662f1ab7..d34ce2421eac2 100644
--- a/mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestLoopParametricTiling.cpp
@@ -47,7 +47,7 @@ class SimpleParametricLoopTilingPass
   }
 
   ListOption<int64_t> sizes{
-      *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "test-outer-loop-sizes",
       llvm::cl::desc(
           "fixed number of iterations that the outer loops should have")};
 };

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp b/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
index b12289a6d93a8..ab127c8918b77 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
@@ -41,7 +41,7 @@ class TestSpirvEntryPointABIPass
           "Workgroup size to use for all gpu.func kernels in the module, "
           "specified with x-dimension first, y-dimension next and z-dimension "
           "last. Unspecified dimensions will be set to 1"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+      llvm::cl::ZeroOrMore};
 };
 } // namespace
 

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03fc3a88c1d91..63f8060533fa7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -330,7 +330,7 @@ struct TestVectorDistributePatterns
     registry.insert<AffineDialect>();
   }
   ListOption<int32_t> multiplicity{
-      *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "distribution-multiplicity",
       llvm::cl::desc("Set the multiplicity used for distributing vector")};
 
   void runOnOperation() override {

diff  --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
index c56d4488ae6d0..123a2aa46442a 100644
--- a/mlir/test/lib/IR/TestDiagnostics.cpp
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -54,7 +54,7 @@ struct TestDiagnosticFilterPass
   }
 
   ListOption<std::string> filters{
-      *this, "filters", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "filters",
       llvm::cl::desc("Specifies the diagnostic file name filters.")};
 };
 

diff  --git a/mlir/test/lib/Pass/TestDynamicPipeline.cpp b/mlir/test/lib/Pass/TestDynamicPipeline.cpp
index 90b88202f7e54..114cf8c3b47b5 100644
--- a/mlir/test/lib/Pass/TestDynamicPipeline.cpp
+++ b/mlir/test/lib/Pass/TestDynamicPipeline.cpp
@@ -99,7 +99,7 @@ class TestDynamicPipelinePass
       llvm::cl::desc("The pipeline description that "
                      "will run on the filtered function.")};
   ListOption<std::string> opNames{
-      *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
+      *this, "op-name",
       llvm::cl::desc("List of function name to apply the pipeline to")};
 };
 } // namespace

diff  --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 260bf0bf0a1b6..cd0e4615cfa1d 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -48,11 +48,9 @@ class TestOptionsPass
 public:
   struct Options : public PassPipelineOptions<Options> {
     ListOption<int> listOption{*this, "list",
-                               llvm::cl::MiscFlags::CommaSeparated,
                                llvm::cl::desc("Example list option")};
     ListOption<std::string> stringListOption{
-        *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
-        llvm::cl::desc("Example string list option")};
+        *this, "string-list", llvm::cl::desc("Example string list option")};
     Option<std::string> stringOption{*this, "string",
                                      llvm::cl::desc("Example string option")};
   };
@@ -70,11 +68,10 @@ class TestOptionsPass
     return "Test options parsing capabilities";
   }
 
-  ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
+  ListOption<int> listOption{*this, "list",
                              llvm::cl::desc("Example list option")};
   ListOption<std::string> stringListOption{
-      *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
-      llvm::cl::desc("Example string list option")};
+      *this, "string-list", llvm::cl::desc("Example string list option")};
   Option<std::string> stringOption{*this, "string",
                                    llvm::cl::desc("Example string option")};
 };


        


More information about the Mlir-commits mailing list