[Mlir-commits] [mlir] 08d7377 - [mlir] Enable DRR variadic operand matching

Logan Chien llvmlistbot at llvm.org
Mon Aug 28 14:14:14 PDT 2023


Author: Logan Chien
Date: 2023-08-28T14:11:32-07:00
New Revision: 08d7377b67358496a409080fac22f3f7c077fb63

URL: https://github.com/llvm/llvm-project/commit/08d7377b67358496a409080fac22f3f7c077fb63
DIFF: https://github.com/llvm/llvm-project/commit/08d7377b67358496a409080fac22f3f7c077fb63.diff

LOG: [mlir] Enable DRR variadic operand matching

This commit enables DRR rewriter to match a fixed number of sub-operands
as a variadic operand.

Differential Review: https://reviews.llvm.org/D157359

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/IR/PatternBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index dd996baf3cd957..2ae99f4ea32b3e 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -647,6 +647,50 @@ correspond to multiple actual values.
 
 [TODO]
 
+#### Match variadic operand
+
+Use the `variadic` DAG node to match a variadic operand with a fixed number of
+actual sub-operands.
+
+For example, assume that `ConcatenateOp` is an operation with a variadic
+operand:
+
+```tablegen
+def ConcatenateOp : TEST_Op<"concatenate"> {
+  let arguments = (ins
+    Variadic<AnyTensor>:$inputs,
+    I32Attr:$axis
+  );
+
+  let results = (outs
+    AnyTensor:$output
+  );
+}
+```
+
+We can match `ConcatenateOp` with exactly 2 actual operands with:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic $input0, $input1), $axis),
+          ...>;
+```
+
+The variadic sub-operands can be sub-DAGs to be matched:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic (SomeOp $a), (AnotherOp $b, $c)), $axis),
+          (OtherOp $a, $b, $c)>;
+```
+
+The variadic DAG can be bound to a symbol, which refers to the full
+`operand_range`:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic:$inputs $input0, $input1),
+                         ConstantAttr<I32Attr, "0">),
+          (VStackOp $inputs)>;
+```
+
 ### Supplying additional constraints
 
 Constraints can be placed on op arguments when matching. But sometimes we need

diff --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td
index 919fb884adb0e9..33f8e2ac316d5c 100644
--- a/mlir/include/mlir/IR/PatternBase.td
+++ b/mlir/include/mlir/IR/PatternBase.td
@@ -218,6 +218,22 @@ def returnType;
 // `either` while pattern matching.
 def either;
 
+// Directive used to match variadic operands. This directive only matches if
+// the variadic operand has the same length as the specified formal
+// sub-dags.
+//
+// ```
+// (VariadicOp (variadic:$input1 $input1a, $input1b),
+//             (variadic:$input2 $input2a, $input2b, $input2c),
+//             $attr1, $attr2)
+// ```
+//
+// The pattern above only matches if the `$input1` operand is of length 2,
+// `$input2` is of length 3, and all sub-dags match respectively. The `$input1`
+// symbol denotes the full variadic operand range. The `$input1a` symbol
+// denotes the first operand in the variadic sub-operands.
+def variadic;
+
 //===----------------------------------------------------------------------===//
 // Common value constraints
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4511ba7dd833ef..80f38fdeffee07 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -22,6 +22,7 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
+#include <optional>
 #include <unordered_map>
 
 namespace llvm {
@@ -189,6 +190,9 @@ class DagNode {
   // Returns whether this DAG is an `either` specifier.
   bool isEither() const;
 
+  // Returns whether this DAG is a `variadic` specifier.
+  bool isVariadic() const;
+
   // Returns true if this DAG node is an operation.
   bool isOperation() const;
 
@@ -268,9 +272,94 @@ class SymbolInfoMap {
     // Allow SymbolInfoMap to access private methods.
     friend class SymbolInfoMap;
 
-    // DagNode and DagLeaf are accessed by value which means it can't be used as
-    // identifier here. Use an opaque pointer type instead.
-    using DagAndConstant = std::pair<const void *, int>;
+    // Structure to uniquely distinguish different locations of the symbols.
+    //
+    // * If a symbol is defined as an operand of an operation, `dag` specifies
+    //   the DAG of the operation, `operandIndexOrNumValues` specifies the
+    //   operand index, and `variadicSubIndex` must be set to `std::nullopt`.
+    //
+    // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
+    //   of the parent operation, `operandIndexOrNumValues` specifies the
+    //   declared operand index of the variadic operand in the parent
+    //   operation.
+    //
+    //   - If the symbol is defined as a result of `variadic` DAG, the
+    //     `variadicSubIndex` must be set to `std::nullopt`, which means that
+    //     the symbol binds to the full operand range.
+    //
+    //   - If the symbol is defined as an operand, the `variadicSubIndex`
+    //     must be set to the index within the variadic sub-operand list.
+    //
+    // * If a symbol is defined in an `either` DAG, `dag` specifies the DAG
+    //   of the parent operation, `operandIndexOrNumValues` specifies the
+    //   operand index in the parent operation (not necessarily the index in
+    //   the DAG).
+    //
+    // * If a symbol is defined as a result, `operandIndexOrNumValues`
+    //   specifies the number of returned values.
+    //
+    // Example 1:
+    //
+    //   def : Pat<(OpA $input0, $input1), ...>;
+    //
+    //   $input0: (OpA, 0, nullopt)
+    //   $input1: (OpA, 1, nullopt)
+    //
+    // Example 2:
+    //
+    //   def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
+    //                  (variadic:$input1 $input1a, $input1b, $input1c)),
+    //             ...>;
+    //
+    //   $input0:  (OpB, 0, nullopt)
+    //   $input0a: (OpB, 0, 0)
+    //   $input0b: (OpB, 0, 1)
+    //   $input1:  (OpB, 1, nullopt)
+    //   $input1a: (OpB, 1, 0)
+    //   $input1b: (OpB, 1, 1)
+    //   $input1c: (OpB, 1, 2)
+    //
+    // Example 3:
+    //
+    //   def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
+    //
+    //   $input0: (OpC, 0, nullopt)
+    //   $input1: (OpC, 1, nullopt)
+    //   $input2: (OpC, 2, nullopt)
+    //
+    // Example 4:
+    //
+    //   def ThreeResultOp : TEST_Op<...> {
+    //     let results = (outs
+    //       AnyType:$result1,
+    //       AnyType:$result2,
+    //       AnyType:$result3
+    //     );
+    //   }
+    //
+    //   def : Pat<...,
+    //             (ThreeResultOp:$result ...)>;
+    //
+    //   $result: (nullptr, 3, nullopt)
+    //
+    struct DagAndConstant {
+      // DagNode and DagLeaf are accessed by value which means it can't be used
+      // as identifier here. Use an opaque pointer type instead.
+      const void *dag;
+      int operandIndexOrNumValues;
+      std::optional<int> variadicSubIndex;
+
+      DagAndConstant(const void *dag, int operandIndexOrNumValues,
+                     std::optional<int> variadicSubIndex)
+          : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
+            variadicSubIndex(variadicSubIndex) {}
+
+      bool operator==(const DagAndConstant &rhs) const {
+        return dag == rhs.dag &&
+               operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
+               variadicSubIndex == rhs.variadicSubIndex;
+      }
+    };
 
     // What kind of entity this symbol represents:
     // * Attr: op attribute
@@ -288,14 +377,18 @@ class SymbolInfoMap {
 
     // Static methods for creating SymbolInfo.
     static SymbolInfo getAttr(const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
+      return SymbolInfo(op, Kind::Attr,
+                        DagAndConstant(nullptr, index, std::nullopt));
     }
     static SymbolInfo getAttr() {
       return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
     }
-    static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
+    static SymbolInfo
+    getOperand(DagNode node, const Operator *op, int operandIndex,
+               std::optional<int> variadicSubIndex = std::nullopt) {
       return SymbolInfo(op, Kind::Operand,
-                        DagAndConstant(node.getAsOpaquePointer(), index));
+                        DagAndConstant(node.getAsOpaquePointer(), operandIndex,
+                                       variadicSubIndex));
     }
     static SymbolInfo getResult(const Operator *op) {
       return SymbolInfo(op, Kind::Result, std::nullopt);
@@ -305,7 +398,7 @@ class SymbolInfoMap {
     }
     static SymbolInfo getMultipleValues(int numValues) {
       return SymbolInfo(nullptr, Kind::MultipleValues,
-                        DagAndConstant(nullptr, numValues));
+                        DagAndConstant(nullptr, numValues, std::nullopt));
     }
 
     // Returns the number of static values this symbol corresponds to.
@@ -333,18 +426,23 @@ class SymbolInfoMap {
                                const char *separator) const;
 
     // The argument index (for `Attr` and `Operand` only)
-    int getArgIndex() const { return (*dagAndConstant).second; }
+    int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
 
     // The number of values in the MultipleValue
-    int getSize() const { return (*dagAndConstant).second; }
+    int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
+
+    // The variadic sub-operands index (for variadic `Operand` only)
+    std::optional<int> getVariadicSubIndex() const {
+      return dagAndConstant->variadicSubIndex;
+    }
 
     const Operator *op; // The op where the bound entity belongs
     Kind kind;          // The kind of the bound entity
 
-    // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
-    // the size of MultipleValue symbol). Note that operands may be bound to the
-    // same symbol, use the DagNode and index to distinguish them. For `Attr`
-    // and MultipleValue, the Dag part will be nullptr.
+    // The tuple of DagNode pointer and two constant values (for `Attr`,
+    // `Operand` and the size of MultipleValue symbol). Note that operands may
+    // be bound to the same symbol, use the DagNode and index to distinguish
+    // them. For `Attr` and MultipleValue, the Dag part will be nullptr.
     std::optional<DagAndConstant> dagAndConstant;
 
     // Alternative name for the symbol. It is used in case the name
@@ -367,7 +465,8 @@ class SymbolInfoMap {
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
   // Returns false if `symbol` is already bound and symbols are not operands.
   bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
-                      int argIndex);
+                      int argIndex,
+                      std::optional<int> variadicSubIndex = std::nullopt);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
   // `symbol` is already bound.
@@ -397,7 +496,8 @@ class SymbolInfoMap {
   // Returns an iterator to the information of the given symbol named as `key`,
   // with index `argIndex` for operator `op`.
   const_iterator findBoundSymbol(StringRef key, DagNode node,
-                                 const Operator &op, int argIndex) const;
+                                 const Operator &op, int argIndex,
+                                 std::optional<int> variadicSubIndex) const;
   const_iterator findBoundSymbol(StringRef key,
                                  const SymbolInfo &symbolInfo) const;
 

diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index d9e1d6c7f89152..3526192c9ee850 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -115,7 +115,8 @@ bool DagNode::isNativeCodeCall() const {
 
 bool DagNode::isOperation() const {
   return !isNativeCodeCall() && !isReplaceWithValue() &&
-         !isLocationDirective() && !isReturnTypeDirective() && !isEither();
+         !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
+         !isVariadic();
 }
 
 llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -193,6 +194,11 @@ bool DagNode::isEither() const {
   return dagOpDef->getName() == "either";
 }
 
+bool DagNode::isVariadic() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "variadic";
+}
+
 void DagNode::print(raw_ostream &os) const {
   if (node)
     node->print(os);
@@ -296,9 +302,10 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   case Kind::Operand: {
     assert(index < 0);
     auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
-    // If this operand is variadic, then return a range. Otherwise, return the
-    // value itself.
-    if (operand->isVariableLength()) {
+    // If this operand is variadic and this SymbolInfo doesn't have a range
+    // index, then return the full variadic operand_range. Otherwise, return
+    // the value itself.
+    if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
       auto repl = formatv(fmt, name);
       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
       return std::string(repl);
@@ -426,7 +433,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
 }
 
 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
-                                   const Operator &op, int argIndex) {
+                                   const Operator &op, int argIndex,
+                                   std::optional<int> variadicSubIndex) {
   StringRef name = getValuePackName(symbol);
   if (name != symbol) {
     auto error = formatv(
@@ -434,9 +442,10 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
     PrintFatalError(loc, error);
   }
 
-  auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
-                     ? SymbolInfo::getAttr(&op, argIndex)
-                     : SymbolInfo::getOperand(node, &op, argIndex);
+  auto symInfo =
+      op.getArg(argIndex).is<NamedAttribute *>()
+          ? SymbolInfo::getAttr(&op, argIndex)
+          : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
 
   std::string key = symbol.str();
   if (symbolInfoMap.count(key)) {
@@ -499,8 +508,10 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
 
 SymbolInfoMap::const_iterator
 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
-                               int argIndex) const {
-  return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
+                               int argIndex,
+                               std::optional<int> variadicSubIndex) const {
+  return findBoundSymbol(
+      key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
 }
 
 SymbolInfoMap::const_iterator
@@ -831,6 +842,33 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       }
     };
 
+    // The operand in `variadic` DAG should be bound to the operation in the
+    // parent DagNode. The range index must be included as well to distinguish
+    // (potentially) repeating argName within the `variadic` DAG.
+    auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
+                                       int opArgIdx) {
+      auto treeName = tree.getSymbol();
+      if (!treeName.empty()) {
+        // If treeName is specified, bind to the full variadic operand_range.
+        verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
+                                          std::nullopt),
+                   treeName);
+      }
+
+      for (int i = 0; i < tree.getNumArgs(); ++i) {
+        if (DagNode subTree = tree.getArgAsNestedDag(i)) {
+          collectBoundSymbols(subTree, infoMap, isSrcPattern);
+        } else {
+          auto argName = tree.getArgName(i);
+          if (!argName.empty() && argName != "_") {
+            verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
+                                              /*variadicSubIndex=*/i),
+                       argName);
+          }
+        }
+      }
+    };
+
     for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
       if (auto treeArg = tree.getArgAsNestedDag(i)) {
         if (treeArg.isEither()) {
@@ -843,6 +881,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           //
           //  (FooOp arg0, arg1, arg2)
           ++opArgIdx;
+        } else if (treeArg.isVariadic()) {
+          collectSymbolInVariadic(tree, treeArg, opArgIdx);
         } else {
           // This DAG node argument is a DAG node itself. Go inside recursively.
           collectBoundSymbols(treeArg, infoMap, isSrcPattern);

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bb48278ee6b8b3..a684ebdbe4c356 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1646,6 +1646,76 @@ def : Pat<(OneI32ResultOp),
               (replaceWithValue $results__2),
               ConstantAttr<I32Attr, "2">)>;
 
+// Variadic structured matching
+def MixedVOperandOp4 : TEST_Op<"mixed_variadic_in4"> {
+  let arguments = (ins
+    Variadic<I32>:$input1,
+    I32:$input2,
+    I32Attr:$attr1
+  );
+}
+
+def MixedVOperandOp5 : TEST_Op<"mixed_variadic_in5"> {
+  let arguments = (ins
+    I32:$input1,
+    I32:$input2,
+    I32:$input3,
+    I32Attr:$attr1,
+    StrAttr:$pattern_name
+  );
+}
+
+// Helper op to test variadic recursive pattern matching
+def MixedVOperandInOutI32Op : TEST_Op<"mixed_variadic_in_out_i32"> {
+  let arguments = (ins
+    I32:$input
+  );
+  let results = (outs
+    I32:$output
+  );
+}
+
+def : Pat<
+  (MixedVOperandOp4 (variadic $input1a, $input1b), $input2,
+                    ConstantAttr<I32Attr, "0">:$attr1),
+  (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchVariadic">)>;
+
+def : Pat<
+  (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
+                              (MixedVOperandInOutI32Op $input1b)),
+                    $input2, ConstantAttr<I32Attr, "1">:$attr1),
+  (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchVariadicSubDag">)>;
+
+def : Pat<
+  (MixedVOperandOp4 (variadic $input1, $input1), $input2,
+                    ConstantAttr<I32Attr, "2">:$attr1),
+  (MixedVOperandOp5 $input1, $input1, $input2, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchVariadicSameSymbol">)>;
+
+def MixedVOperandOp6 : TEST_Op<"mixed_variadic_in6",
+                               [SameVariadicOperandSize]> {
+  let arguments = (ins
+    Variadic<I32>:$input1,
+    Variadic<I32>:$input2,
+    I32Attr:$attr1
+  );
+}
+
+def : Pat<
+  (MixedVOperandOp6 (variadic:$input1 $input1a, $input1b),
+                    (variadic:$input2 $input2a, $input2b),
+                    ConstantAttr<I32Attr, "1">:$attr1),
+  (MixedVOperandOp6 $input2, $input1, ConstantAttr<I32Attr, "-1">)>;
+
+def : Pat<
+  (MixedVOperandOp6 (variadic $input1a, $input1b),
+                    (variadic $input2a, $input2b),
+                    ConstantAttr<I32Attr, "2">:$attr1),
+  (MixedVOperandOp5 $input2a, $input2b, $input1b, $attr1,
+                    ConstantStrAttr<StrAttr, "MatchMultiVariadicSubSymbol">)>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (either)
 

diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 4c1182fa9eb729..5f776338bd40be 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -515,6 +515,67 @@ func.func @generateVariadicOutputOpInNestedPattern() -> (i32) {
   return %0 : i32
 }
 
+// CHECK-LABEL: @testMatchVariadic
+func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchVariadic"}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 0 : i32} : (i32, i32, i32) -> ()
+
+  // Note: Not rewritten because variadic operand size mismatches.
+  // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) <{attr1 = 0 : i32}> : (i32, i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) {attr1 = 0 : i32} : (i32, i32, i32, i32) -> ()
+
+  return
+}
+
+// CHECK-LABEL: @testMatchVariadicSubDag
+func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
+  // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
+  %0 = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
+  // CHECK: %[[IN1:.*]] = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32
+  %1 = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32
+
+  // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32, pattern_name = "MatchVariadicSubDag"}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%0, %1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> ()
+
+  // Note: MatchVariadicSubDag doesn't apply
+  // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> ()
+
+  return
+}
+
+// CHECK-LABEL: @testMatchVariadicSameSymbol
+func.func @testMatchVariadicSameSymbol(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
+  // CHECK: "test.mixed_variadic_in5"(%arg0, %arg0, %arg2) <{attr1 = 2 : i32, pattern_name = "MatchVariadicSameSymbol"}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg0, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> ()
+
+  // Note: MatchVariadicSameSymbol doesn't apply.
+  // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> ()
+
+  return
+}
+
+// CHECK-LABEL: @testMatchAndRewriteVariadicFullRange
+func.func @testMatchAndRewriteVariadicFullRange(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK: "test.mixed_variadic_in6"(%arg2, %arg3, %arg0, %arg1) <{attr1 = -1 : i32}> : (i32, i32, i32, i32) -> ()
+  "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 1 : i32} : (i32, i32, i32, i32) -> ()
+
+  // Note: MatchAndRewriteVariadicFullRange doesn't apply because the length of each variadic operand is not equal to 2.
+  // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1) <{attr1 = 1 : i32}> : (i32, i32) -> ()
+  "test.mixed_variadic_in6"(%arg0, %arg1) {attr1 = 1 : i32} : (i32, i32) -> ()
+
+  return
+}
+
+// CHECK-LABEL: @testMatchMultiVariadicSubSymbol
+func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+  // CHECK: "test.mixed_variadic_in5"(%arg2, %arg3, %arg1) <{attr1 = 2 : i32, pattern_name = "MatchMultiVariadicSubSymbol"}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 2 : i32} : (i32, i32, i32, i32) -> ()
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test that natives calls are only called once during rewrites.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 8b5ef5c6e01829..875f5b71de7f3c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -103,8 +103,9 @@ class PatternEmitter {
   // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
   // bound name and the constraint of the operand respectively.
   void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
-                        DagLeaf operandMatcher, StringRef argName,
-                        int argIndex);
+                        int operandIndex, DagLeaf operandMatcher,
+                        StringRef argName, int argIndex,
+                        std::optional<int> variadicSubIndex);
 
   // Emits C++ statements for matching the operands which can be matched in
   // either order.
@@ -112,6 +113,11 @@ class PatternEmitter {
                               StringRef opName, int argIndex, int &operandIndex,
                               int depth);
 
+  // Emits C++ statements for matching a variadic operand.
+  void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
+                                StringRef opName, int argIndex,
+                                int &operandIndex, int depth);
+
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
@@ -262,6 +268,11 @@ class StaticMatcherHelper {
   // Determine if we should inline the match logic or delegate to a static
   // function.
   bool useStaticMatcher(DagNode node) {
+    // either/variadic node must be associated to the parentOp, thus we can't
+    // emit a static matcher rooted at them.
+    if (node.isEither() || node.isVariadic())
+      return false;
+
     return refStats[node] > kStaticMatcherThreshold;
   }
 
@@ -462,6 +473,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
       if (argTree.isEither())
         PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
+      if (argTree.isVariadic())
+        PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
 
       os << "::mlir::Value " << argName << ";\n";
     } else {
@@ -596,6 +609,18 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
         continue;
       }
       if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
+        if (argTree.isVariadic()) {
+          if (!operand->isVariadic()) {
+            auto error = formatv("variadic DAG construct can't match op {0}'s "
+                                 "non-variadic operand #{1}",
+                                 op.getOperationName(), opArgIdx);
+            PrintFatalError(loc, error);
+          }
+          emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
+                                   nextOperand, depth);
+          ++nextOperand;
+          continue;
+        }
         if (operand->isVariableLength()) {
           auto error = formatv("use nested DAG construct to match op {0}'s "
                                "variadic operand #{1} unsupported now",
@@ -627,9 +652,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     if (opArg.is<NamedTypeConstraint *>()) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
-      emitOperandMatch(tree, castedName, operandName.str(),
+      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
-                       /*argName=*/tree.getArgName(i), opArgIdx);
+                       /*argName=*/tree.getArgName(i), opArgIdx,
+                       /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
     } else if (opArg.is<NamedAttribute *>()) {
       emitAttributeMatch(tree, opName, opArgIdx, depth);
@@ -643,11 +669,12 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
 }
 
 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
-                                      StringRef operandName,
+                                      StringRef operandName, int operandIndex,
                                       DagLeaf operandMatcher, StringRef argName,
-                                      int argIndex) {
+                                      int argIndex,
+                                      std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
+  auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -682,7 +709,11 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
   // Capture the value
   // `$_` is a special symbol to ignore op argument matching.
   if (!argName.empty() && argName != "_") {
-    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+                                             variadicSubIndex);
+    if (res == symbolInfoMap.end())
+      PrintFatalError(loc, formatv("symbol not found: {0}", argName));
+
     os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
   }
 }
@@ -735,8 +766,10 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
       tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
     } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
       emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
+                       operandIndex,
                        /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
-                       /*argName=*/eitherArgTree.getArgName(i), argIndex);
+                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
+                       /*variadicSubIndex=*/std::nullopt);
       ++operandIndex;
     } else {
       PrintFatalError(loc, "either can only be applied on operand");
@@ -764,6 +797,67 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
   os.unindent().unindent() << "}\n";
 }
 
+void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
+                                              DagNode variadicArgTree,
+                                              StringRef opName, int argIndex,
+                                              int &operandIndex, int depth) {
+  Operator &op = tree.getDialectOp(opMap);
+
+  os << "{\n";
+  os.indent();
+
+  os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
+                opName, operandIndex);
+  os << formatv("if (variadic_operand_range.size() != {0}) "
+                "return ::mlir::failure();\n",
+                variadicArgTree.getNumArgs());
+
+  StringRef variadicTreeName = variadicArgTree.getSymbol();
+  if (!variadicTreeName.empty()) {
+    auto res =
+        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+                                      /*variadicSubIndex=*/std::nullopt);
+    if (res == symbolInfoMap.end())
+      PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
+
+    os << formatv("{0} = variadic_operand_range;\n",
+                  res->second.getVarName(variadicTreeName));
+  }
+
+  for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
+    if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
+      if (!argTree.isOperation())
+        PrintFatalError(loc, "variadic only accepts operation sub-dags");
+
+      os << "{\n";
+      os.indent();
+
+      std::string argName = formatv("local_op_{0}", i).str();
+      os << formatv("auto *{0} = "
+                    "variadic_operand_range[{1}].getDefiningOp();\n",
+                    argName, i);
+      emitMatchCheck(
+          opName, /*matchStr=*/argName,
+          formatv("\"There's no operation that defines variadic operand "
+                  "{0} (variadic sub-opearnd #{1}) of {2}\"",
+                  operandIndex, i, opName));
+      emitMatch(argTree, argName, depth + 1);
+      os << formatv("tblgen_ops.push_back({0});\n", argName);
+
+      os.unindent() << "}\n";
+    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+      auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
+      emitOperandMatch(tree, opName, operandName.str(), operandIndex,
+                       /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
+                       /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
+    } else {
+      PrintFatalError(loc, "variadic can only be applied on operand");
+    }
+  }
+
+  os.unindent() << "}\n";
+}
+
 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);


        


More information about the Mlir-commits mailing list