[Mlir-commits] [mlir] 08d7377 - [mlir] Enable DRR variadic operand matching
Logan Chien
llvmlistbot at llvm.org
Mon Aug 28 14:14:14 PDT 2023
Author: Logan Chien
Date: 2023-08-28T14:11:32-07:00
New Revision: 08d7377b67358496a409080fac22f3f7c077fb63
URL: https://github.com/llvm/llvm-project/commit/08d7377b67358496a409080fac22f3f7c077fb63
DIFF: https://github.com/llvm/llvm-project/commit/08d7377b67358496a409080fac22f3f7c077fb63.diff
LOG: [mlir] Enable DRR variadic operand matching
This commit enables the DRR rewriter to match a variadic operand that has a
fixed number of sub-operands.
Differential Review: https://reviews.llvm.org/D157359
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/include/mlir/IR/PatternBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index dd996baf3cd957..2ae99f4ea32b3e 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -647,6 +647,50 @@ correspond to multiple actual values.
[TODO]
+#### Match variadic operand
+
+Use the `variadic` DAG node to match a variadic operand with a fixed number of
+actual sub-operands.
+
+For example, assume that `ConcatenateOp` is an operation with a variadic
+operand:
+
+```tablegen
+def ConcatenateOp : TEST_Op<"concatenate"> {
+ let arguments = (ins
+ Variadic<AnyTensor>:$inputs,
+ I32Attr:$axis
+ );
+
+ let results = (outs
+ AnyTensor:$output
+ );
+}
+```
+
+We can match `ConcatenateOp` with exactly 2 actual operands with:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic $input0, $input1), $axis),
+ ...>;
+```
+
+The variadic sub-operands can be sub-DAGs to be matched:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic (SomeOp $a), (AnotherOp $b, $c)), $axis),
+ (OtherOp $a, $b, $c)>;
+```
+
+The variadic DAG can be bound to a symbol, which refers to the full
+`operand_range`:
+
+```tablegen
+def : Pat<(ConcatenateOp (variadic:$inputs $input0, $input1),
+ ConstantAttr<I32Attr, "0">),
+ (VStackOp $inputs)>;
+```
+
### Supplying additional constraints
Constraints can be placed on op arguments when matching. But sometimes we need
diff --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td
index 919fb884adb0e9..33f8e2ac316d5c 100644
--- a/mlir/include/mlir/IR/PatternBase.td
+++ b/mlir/include/mlir/IR/PatternBase.td
@@ -218,6 +218,22 @@ def returnType;
// `either` while pattern matching.
def either;
+// Directive used to match variadic operands. This directive only matches if
+// the variadic operand has exactly as many values as there are specified
+// sub-dags.
+//
+// ```
+// (VariadicOp (variadic:$input1 $input1a, $input1b),
+// (variadic:$input2 $input2a, $input2b, $input2c),
+// $attr1, $attr2)
+// ```
+//
+// The pattern above only matches if the `$input1` operand is of length 2,
+// `$input2` is of length 3, and all sub-dags match respectively. The `$input1`
+// symbol denotes the full variadic operand range. The `$input1a` symbol
+// denotes the first operand in the variadic sub-operands.
+def variadic;
+
//===----------------------------------------------------------------------===//
// Common value constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4511ba7dd833ef..80f38fdeffee07 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -22,6 +22,7 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
+#include <optional>
#include <unordered_map>
namespace llvm {
@@ -189,6 +190,9 @@ class DagNode {
// Returns whether this DAG is an `either` specifier.
bool isEither() const;
+ // Returns whether this DAG is a `variadic` specifier.
+ bool isVariadic() const;
+
// Returns true if this DAG node is an operation.
bool isOperation() const;
@@ -268,9 +272,94 @@ class SymbolInfoMap {
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
- // DagNode and DagLeaf are accessed by value which means it can't be used as
- // identifier here. Use an opaque pointer type instead.
- using DagAndConstant = std::pair<const void *, int>;
+ // Structure to uniquely distinguish different locations of the symbols.
+ //
+ // * If a symbol is defined as an operand of an operation, `dag` specifies
+ // the DAG of the operation, `operandIndexOrNumValues` specifies the
+ // operand index, and `variadicSubIndex` must be set to `std::nullopt`.
+ //
+ // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
+ // of the parent operation, `operandIndexOrNumValues` specifies the
+ // declared operand index of the variadic operand in the parent
+ // operation.
+ //
+ // - If the symbol is bound to the `variadic` DAG itself, the
+ // `variadicSubIndex` must be set to `std::nullopt`, which means that
+ // the symbol binds to the full operand range.
+ //
+ // - If the symbol is defined as an operand inside the `variadic` DAG,
+ // `variadicSubIndex` must be set to its index within the sub-operand list.
+ //
+ // * If a symbol is defined in an `either` DAG, `dag` specifies the DAG
+ // of the parent operation, `operandIndexOrNumValues` specifies the
+ // operand index in the parent operation (not necessarily the index in the
+ // DAG), and `variadicSubIndex` must be set to `std::nullopt`.
+ //
+ // * If a symbol is defined as a result, `dag` is null and
+ // `operandIndexOrNumValues` specifies the number of returned values.
+ //
+ // Example 1:
+ //
+ // def : Pat<(OpA $input0, $input1), ...>;
+ //
+ // $input0: (OpA, 0, nullopt)
+ // $input1: (OpA, 1, nullopt)
+ //
+ // Example 2:
+ //
+ // def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
+ // (variadic:$input1 $input1a, $input1b, $input1c)),
+ // ...>;
+ //
+ // $input0: (OpB, 0, nullopt)
+ // $input0a: (OpB, 0, 0)
+ // $input0b: (OpB, 0, 1)
+ // $input1: (OpB, 1, nullopt)
+ // $input1a: (OpB, 1, 0)
+ // $input1b: (OpB, 1, 1)
+ // $input1c: (OpB, 1, 2)
+ //
+ // Example 3:
+ //
+ // def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
+ //
+ // $input0: (OpC, 0, nullopt)
+ // $input1: (OpC, 1, nullopt)
+ // $input2: (OpC, 2, nullopt)
+ //
+ // Example 4:
+ //
+ // def ThreeResultOp : TEST_Op<...> {
+ // let results = (outs
+ // AnyType:$result1,
+ // AnyType:$result2,
+ // AnyType:$result3
+ // );
+ // }
+ //
+ // def : Pat<...,
+ // (ThreeResultOp:$result ...)>;
+ //
+ // $result: (nullptr, 3, nullopt)
+ //
+ struct DagAndConstant {
+ // DagNode and DagLeaf are accessed by value which means it can't be used
+ // as identifier here. Use an opaque pointer type instead.
+ const void *dag;
+ int operandIndexOrNumValues;
+ std::optional<int> variadicSubIndex;
+
+ DagAndConstant(const void *dag, int operandIndexOrNumValues,
+ std::optional<int> variadicSubIndex)
+ : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
+ variadicSubIndex(variadicSubIndex) {}
+
+ bool operator==(const DagAndConstant &rhs) const {
+ return dag == rhs.dag &&
+ operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
+ variadicSubIndex == rhs.variadicSubIndex;
+ }
+ };
// What kind of entity this symbol represents:
// * Attr: op attribute
@@ -288,14 +377,18 @@ class SymbolInfoMap {
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
- return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
+ return SymbolInfo(op, Kind::Attr,
+ DagAndConstant(nullptr, index, std::nullopt));
}
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
}
- static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
+ static SymbolInfo
+ getOperand(DagNode node, const Operator *op, int operandIndex,
+ std::optional<int> variadicSubIndex = std::nullopt) {
return SymbolInfo(op, Kind::Operand,
- DagAndConstant(node.getAsOpaquePointer(), index));
+ DagAndConstant(node.getAsOpaquePointer(), operandIndex,
+ variadicSubIndex));
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, std::nullopt);
@@ -305,7 +398,7 @@ class SymbolInfoMap {
}
static SymbolInfo getMultipleValues(int numValues) {
return SymbolInfo(nullptr, Kind::MultipleValues,
- DagAndConstant(nullptr, numValues));
+ DagAndConstant(nullptr, numValues, std::nullopt));
}
// Returns the number of static values this symbol corresponds to.
@@ -333,18 +426,23 @@ class SymbolInfoMap {
const char *separator) const;
// The argument index (for `Attr` and `Operand` only)
- int getArgIndex() const { return (*dagAndConstant).second; }
+ int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
// The number of values in the MultipleValue
- int getSize() const { return (*dagAndConstant).second; }
+ int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
+
+ // The variadic sub-operands index (for variadic `Operand` only)
+ std::optional<int> getVariadicSubIndex() const {
+ return dagAndConstant->variadicSubIndex;
+ }
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
- // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
- // the size of MultipleValue symbol). Note that operands may be bound to the
- // same symbol, use the DagNode and index to distinguish them. For `Attr`
- // and MultipleValue, the Dag part will be nullptr.
+ // The tuple of DagNode pointer and two constant values (for `Attr`,
+ // `Operand` and the size of MultipleValue symbol). Note that operands may
+ // be bound to the same symbol, use the DagNode and index to distinguish
+ // them. For `Attr` and MultipleValue, the Dag part will be nullptr.
std::optional<DagAndConstant> dagAndConstant;
// Alternative name for the symbol. It is used in case the name
@@ -367,7 +465,8 @@ class SymbolInfoMap {
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound and symbols are not operands.
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
- int argIndex);
+ int argIndex,
+ std::optional<int> variadicSubIndex = std::nullopt);
// Binds the given `symbol` to the results the given `op`. Returns false if
// `symbol` is already bound.
@@ -397,7 +496,8 @@ class SymbolInfoMap {
// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
const_iterator findBoundSymbol(StringRef key, DagNode node,
- const Operator &op, int argIndex) const;
+ const Operator &op, int argIndex,
+ std::optional<int> variadicSubIndex) const;
const_iterator findBoundSymbol(StringRef key,
const SymbolInfo &symbolInfo) const;
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index d9e1d6c7f89152..3526192c9ee850 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -115,7 +115,8 @@ bool DagNode::isNativeCodeCall() const {
bool DagNode::isOperation() const {
return !isNativeCodeCall() && !isReplaceWithValue() &&
- !isLocationDirective() && !isReturnTypeDirective() && !isEither();
+ !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
+ !isVariadic();
}
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -193,6 +194,11 @@ bool DagNode::isEither() const {
return dagOpDef->getName() == "either";
}
+bool DagNode::isVariadic() const {
+ auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+ return dagOpDef->getName() == "variadic";
+}
+
void DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
@@ -296,9 +302,10 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
case Kind::Operand: {
assert(index < 0);
auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
- // If this operand is variadic, then return a range. Otherwise, return the
- // value itself.
- if (operand->isVariableLength()) {
+ // If this operand is variadic and this SymbolInfo doesn't have a range
+ // index, then return the full variadic operand_range. Otherwise, return
+ // the value itself.
+ if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return std::string(repl);
@@ -426,7 +433,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
}
bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
- const Operator &op, int argIndex) {
+ const Operator &op, int argIndex,
+ std::optional<int> variadicSubIndex) {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
auto error = formatv(
@@ -434,9 +442,10 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
PrintFatalError(loc, error);
}
- auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
- ? SymbolInfo::getAttr(&op, argIndex)
- : SymbolInfo::getOperand(node, &op, argIndex);
+ auto symInfo =
+ op.getArg(argIndex).is<NamedAttribute *>()
+ ? SymbolInfo::getAttr(&op, argIndex)
+ : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
std::string key = symbol.str();
if (symbolInfoMap.count(key)) {
@@ -499,8 +508,10 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
SymbolInfoMap::const_iterator
SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
- int argIndex) const {
- return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
+ int argIndex,
+ std::optional<int> variadicSubIndex) const {
+ return findBoundSymbol(
+ key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
}
SymbolInfoMap::const_iterator
@@ -831,6 +842,33 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
}
};
+ // The operand in `variadic` DAG should be bound to the operation in the
+ // parent DagNode. The range index must be included as well to distinguish
+ // (potentially) repeating argName within the `variadic` DAG.
+ auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
+ int opArgIdx) {
+ auto treeName = tree.getSymbol();
+ if (!treeName.empty()) {
+ // If treeName is specified, bind to the full variadic operand_range.
+ verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
+ std::nullopt),
+ treeName);
+ }
+
+ for (int i = 0; i < tree.getNumArgs(); ++i) {
+ if (DagNode subTree = tree.getArgAsNestedDag(i)) {
+ collectBoundSymbols(subTree, infoMap, isSrcPattern);
+ } else {
+ auto argName = tree.getArgName(i);
+ if (!argName.empty() && argName != "_") {
+ verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
+ /*variadicSubIndex=*/i),
+ argName);
+ }
+ }
+ }
+ };
+
for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
if (treeArg.isEither()) {
@@ -843,6 +881,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
//
// (FooOp arg0, arg1, arg2)
++opArgIdx;
+ } else if (treeArg.isVariadic()) {
+ collectSymbolInVariadic(tree, treeArg, opArgIdx);
} else {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bb48278ee6b8b3..a684ebdbe4c356 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1646,6 +1646,76 @@ def : Pat<(OneI32ResultOp),
(replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>;
+// Variadic structured matching
+def MixedVOperandOp4 : TEST_Op<"mixed_variadic_in4"> {
+ let arguments = (ins
+ Variadic<I32>:$input1,
+ I32:$input2,
+ I32Attr:$attr1
+ );
+}
+
+def MixedVOperandOp5 : TEST_Op<"mixed_variadic_in5"> {
+ let arguments = (ins
+ I32:$input1,
+ I32:$input2,
+ I32:$input3,
+ I32Attr:$attr1,
+ StrAttr:$pattern_name
+ );
+}
+
+// Helper op to test variadic recursive pattern matching
+def MixedVOperandInOutI32Op : TEST_Op<"mixed_variadic_in_out_i32"> {
+ let arguments = (ins
+ I32:$input
+ );
+ let results = (outs
+ I32:$output
+ );
+}
+
+def : Pat<
+ (MixedVOperandOp4 (variadic $input1a, $input1b), $input2,
+ ConstantAttr<I32Attr, "0">:$attr1),
+ (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+ ConstantStrAttr<StrAttr, "MatchVariadic">)>;
+
+def : Pat<
+ (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
+ (MixedVOperandInOutI32Op $input1b)),
+ $input2, ConstantAttr<I32Attr, "1">:$attr1),
+ (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
+ ConstantStrAttr<StrAttr, "MatchVariadicSubDag">)>;
+
+def : Pat<
+ (MixedVOperandOp4 (variadic $input1, $input1), $input2,
+ ConstantAttr<I32Attr, "2">:$attr1),
+ (MixedVOperandOp5 $input1, $input1, $input2, $attr1,
+ ConstantStrAttr<StrAttr, "MatchVariadicSameSymbol">)>;
+
+def MixedVOperandOp6 : TEST_Op<"mixed_variadic_in6",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins
+ Variadic<I32>:$input1,
+ Variadic<I32>:$input2,
+ I32Attr:$attr1
+ );
+}
+
+def : Pat<
+ (MixedVOperandOp6 (variadic:$input1 $input1a, $input1b),
+ (variadic:$input2 $input2a, $input2b),
+ ConstantAttr<I32Attr, "1">:$attr1),
+ (MixedVOperandOp6 $input2, $input1, ConstantAttr<I32Attr, "-1">)>;
+
+def : Pat<
+ (MixedVOperandOp6 (variadic $input1a, $input1b),
+ (variadic $input2a, $input2b),
+ ConstantAttr<I32Attr, "2">:$attr1),
+ (MixedVOperandOp5 $input2a, $input2b, $input1b, $attr1,
+ ConstantStrAttr<StrAttr, "MatchMultiVariadicSubSymbol">)>;
+
//===----------------------------------------------------------------------===//
// Test Patterns (either)
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 4c1182fa9eb729..5f776338bd40be 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -515,6 +515,67 @@ func.func @generateVariadicOutputOpInNestedPattern() -> (i32) {
return %0 : i32
}
+// CHECK-LABEL: @testMatchVariadic
+func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+ // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchVariadic"}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 0 : i32} : (i32, i32, i32) -> ()
+
+ // Note: Not rewritten because variadic operand size mismatches.
+ // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) <{attr1 = 0 : i32}> : (i32, i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) {attr1 = 0 : i32} : (i32, i32, i32, i32) -> ()
+
+ return
+}
+
+// CHECK-LABEL: @testMatchVariadicSubDag
+func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
+ // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
+ %0 = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32
+ // CHECK: %[[IN1:.*]] = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32
+ %1 = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32
+
+ // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32, pattern_name = "MatchVariadicSubDag"}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%0, %1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> ()
+
+ // Note: MatchVariadicSubDag doesn't apply
+ // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> ()
+
+ return
+}
+
+// CHECK-LABEL: @testMatchVariadicSameSymbol
+func.func @testMatchVariadicSameSymbol(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
+ // CHECK: "test.mixed_variadic_in5"(%arg0, %arg0, %arg2) <{attr1 = 2 : i32, pattern_name = "MatchVariadicSameSymbol"}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%arg0, %arg0, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> ()
+
+ // Note: MatchVariadicSameSymbol doesn't apply.
+ // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> ()
+
+ return
+}
+
+// CHECK-LABEL: @testMatchAndRewriteVariadicFullRange
+func.func @testMatchAndRewriteVariadicFullRange(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+ // CHECK: "test.mixed_variadic_in6"(%arg2, %arg3, %arg0, %arg1) <{attr1 = -1 : i32}> : (i32, i32, i32, i32) -> ()
+ "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 1 : i32} : (i32, i32, i32, i32) -> ()
+
+ // Note: MatchAndRewriteVariadicFullRange doesn't apply because the length of each variadic operand is not equal to 2.
+ // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1) <{attr1 = 1 : i32}> : (i32, i32) -> ()
+ "test.mixed_variadic_in6"(%arg0, %arg1) {attr1 = 1 : i32} : (i32, i32) -> ()
+
+ return
+}
+
+// CHECK-LABEL: @testMatchMultiVariadicSubSymbol
+func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
+ // CHECK: "test.mixed_variadic_in5"(%arg2, %arg3, %arg1) <{attr1 = 2 : i32, pattern_name = "MatchMultiVariadicSubSymbol"}> : (i32, i32, i32) -> ()
+ "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 2 : i32} : (i32, i32, i32, i32) -> ()
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test that natives calls are only called once during rewrites.
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 8b5ef5c6e01829..875f5b71de7f3c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -103,8 +103,9 @@ class PatternEmitter {
// DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
// bound name and the constraint of the operand respectively.
void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
- DagLeaf operandMatcher, StringRef argName,
- int argIndex);
+ int operandIndex, DagLeaf operandMatcher,
+ StringRef argName, int argIndex,
+ std::optional<int> variadicSubIndex);
// Emits C++ statements for matching the operands which can be matched in
// either order.
@@ -112,6 +113,11 @@ class PatternEmitter {
StringRef opName, int argIndex, int &operandIndex,
int depth);
+ // Emits C++ statements for matching a variadic operand.
+ void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
+ StringRef opName, int argIndex,
+ int &operandIndex, int depth);
+
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
@@ -262,6 +268,11 @@ class StaticMatcherHelper {
// Determine if we should inline the match logic or delegate to a static
// function.
bool useStaticMatcher(DagNode node) {
+ // either/variadic node must be associated to the parentOp, thus we can't
+ // emit a static matcher rooted at them.
+ if (node.isEither() || node.isVariadic())
+ return false;
+
return refStats[node] > kStaticMatcherThreshold;
}
@@ -462,6 +473,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither())
PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
+ if (argTree.isVariadic())
+ PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
os << "::mlir::Value " << argName << ";\n";
} else {
@@ -596,6 +609,18 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
continue;
}
if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
+ if (argTree.isVariadic()) {
+ if (!operand->isVariadic()) {
+ auto error = formatv("variadic DAG construct can't match op {0}'s "
+ "non-variadic operand #{1}",
+ op.getOperationName(), opArgIdx);
+ PrintFatalError(loc, error);
+ }
+ emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
+ nextOperand, depth);
+ ++nextOperand;
+ continue;
+ }
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
@@ -627,9 +652,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (opArg.is<NamedTypeConstraint *>()) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
- emitOperandMatch(tree, castedName, operandName.str(),
+ emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
/*operandMatcher=*/tree.getArgAsLeaf(i),
- /*argName=*/tree.getArgName(i), opArgIdx);
+ /*argName=*/tree.getArgName(i), opArgIdx,
+ /*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (opArg.is<NamedAttribute *>()) {
emitAttributeMatch(tree, opName, opArgIdx, depth);
@@ -643,11 +669,12 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
}
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
- StringRef operandName,
+ StringRef operandName, int operandIndex,
DagLeaf operandMatcher, StringRef argName,
- int argIndex) {
+ int argIndex,
+ std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
- auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
+ auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@@ -682,7 +709,11 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
- auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
+ auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+ variadicSubIndex);
+ if (res == symbolInfoMap.end())
+ PrintFatalError(loc, formatv("symbol not found: {0}", argName));
+
os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
}
}
@@ -735,8 +766,10 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
} else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
+ operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
- /*argName=*/eitherArgTree.getArgName(i), argIndex);
+ /*argName=*/eitherArgTree.getArgName(i), argIndex,
+ /*variadicSubIndex=*/std::nullopt);
++operandIndex;
} else {
PrintFatalError(loc, "either can only be applied on operand");
@@ -764,6 +797,67 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
os.unindent().unindent() << "}\n";
}
+void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
+ DagNode variadicArgTree,
+ StringRef opName, int argIndex,
+ int &operandIndex, int depth) {
+ Operator &op = tree.getDialectOp(opMap);
+
+ os << "{\n";
+ os.indent();
+
+ os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
+ opName, operandIndex);
+ os << formatv("if (variadic_operand_range.size() != {0}) "
+ "return ::mlir::failure();\n",
+ variadicArgTree.getNumArgs());
+
+ StringRef variadicTreeName = variadicArgTree.getSymbol();
+ if (!variadicTreeName.empty()) {
+ auto res =
+ symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+ /*variadicSubIndex=*/std::nullopt);
+ if (res == symbolInfoMap.end())
+ PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
+
+ os << formatv("{0} = variadic_operand_range;\n",
+ res->second.getVarName(variadicTreeName));
+ }
+
+ for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
+ if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
+ if (!argTree.isOperation())
+ PrintFatalError(loc, "variadic only accepts operation sub-dags");
+
+ os << "{\n";
+ os.indent();
+
+ std::string argName = formatv("local_op_{0}", i).str();
+ os << formatv("auto *{0} = "
+ "variadic_operand_range[{1}].getDefiningOp();\n",
+ argName, i);
+ emitMatchCheck(
+ opName, /*matchStr=*/argName,
+ formatv("\"There's no operation that defines variadic operand "
+ "{0} (variadic sub-opearnd #{1}) of {2}\"",
+ operandIndex, i, opName));
+ emitMatch(argTree, argName, depth + 1);
+ os << formatv("tblgen_ops.push_back({0});\n", argName);
+
+ os.unindent() << "}\n";
+ } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+ auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
+ emitOperandMatch(tree, opName, operandName.str(), operandIndex,
+ /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
+ /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
+ } else {
+ PrintFatalError(loc, "variadic can only be applied on operand");
+ }
+ }
+
+ os.unindent() << "}\n";
+}
+
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
More information about the Mlir-commits
mailing list