[Mlir-commits] [mlir] 2d99c81 - [mlir-tblgen] Support `either` in Tablegen DRR.
Chia-hung Duan
llvmlistbot at llvm.org
Mon Nov 8 15:17:11 PST 2021
Author: Chia-hung Duan
Date: 2021-11-08T23:16:03Z
New Revision: 2d99c815d7c2f40d9be1270b276768374291b68e
URL: https://github.com/llvm/llvm-project/commit/2d99c815d7c2f40d9be1270b276768374291b68e
DIFF: https://github.com/llvm/llvm-project/commit/2d99c815d7c2f40d9be1270b276768374291b68e.diff
LOG: [mlir-tblgen] Support `either` in Tablegen DRR.
Add a new directive `either` to specify that the operands can be matched in either order
Reviewed By: jpienaar, Mogball
Differential Revision: https://reviews.llvm.org/D110666
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 94c0100d4197..7a9a001ca264 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -774,6 +774,23 @@ Explicitly-specified return types will take precedence over return types
inferred from op traits or user-defined builders. The return types of values
replacing root op results cannot be overridden.
+### `either`
+
+The `either` directive is used to specify that the operands may be matched in
+either order.
+
+```tablegen
+def : Pat<(TwoArgOp (either $firstArg, (AnOp $secondArg))),
+ (...)>;
+```
+
+The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` or
+`"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
+
+Only operands are supported with `either`. Note that an operation with the
+`Commutative` trait doesn't imply that it will have the same behavior as
+`either` during pattern matching.
+
## Debugging Tips
### Run `mlir-tblgen` to see the generated content
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 37bf1d233c2b..7ba93ad2476d 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2730,6 +2730,21 @@ def location;
def returnType;
+// Directive used to specify that the operands may be matched in either order.
+// When two adjacent operands are grouped with `either`, it'll try to match the
+// operands in either ordering of constraints. Example:
+//
+// ```
+// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
+// ```
+// The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` or
+// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
+//
+// Only operands are supported with `either`. Note that an operation with the
+// `Commutative` trait doesn't imply that it will have the same behavior as
+// `either` during pattern matching.
+def either;
+
//===----------------------------------------------------------------------===//
// Attribute and Type generation
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 834ebdace12d..ac8c98e5a490 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -186,6 +186,9 @@ class DagNode {
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
+ // Returns whether this DAG is an `either` specifier.
+ bool isEither() const;
+
// Returns true if this DAG node is an operation.
bool isOperation() const;
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 084345238785..148ca49e65e0 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -113,7 +113,7 @@ bool DagNode::isNativeCodeCall() const {
bool DagNode::isOperation() const {
return !isNativeCodeCall() && !isReplaceWithValue() &&
- !isLocationDirective() && !isReturnTypeDirective();
+ !isLocationDirective() && !isReturnTypeDirective() && !isEither();
}
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -142,7 +142,9 @@ Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
}
int DagNode::getNumOps() const {
- int count = isReplaceWithValue() ? 0 : 1;
+ // We want to get number of operations recursively involved in the DAG tree.
+ // All other directives should be excluded.
+ int count = isOperation() ? 1 : 0;
for (int i = 0, e = getNumArgs(); i != e; ++i) {
if (auto child = getArgAsNestedDag(i))
count += child.getNumOps();
@@ -184,6 +186,11 @@ bool DagNode::isReturnTypeDirective() const {
return dagOpDef->getName() == "returnType";
}
+bool DagNode::isEither() const {
+ auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+ return dagOpDef->getName() == "either";
+}
+
void DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
@@ -764,22 +771,25 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (tree.isOperation()) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
+ int numEither = 0;
- // The pattern might have trailing directives.
+ // We need to exclude the trailing directives, and account for each `either`
+ // directive, which groups two operands of the operation.
int numDirectives = 0;
for (int i = numTreeArgs - 1; i >= 0; --i) {
if (auto dagArg = tree.getArgAsNestedDag(i)) {
if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
++numDirectives;
- else
- break;
+ else if (dagArg.isEither())
+ ++numEither;
}
}
- if (numOpArgs != numTreeArgs - numDirectives) {
- auto err = formatv("op '{0}' argument number mismatch: "
- "{1} in pattern vs. {2} in definition",
- op.getOperationName(), numTreeArgs, numOpArgs);
+ if (numOpArgs != numTreeArgs - numDirectives + numEither) {
+ auto err =
+ formatv("op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ op.getOperationName(), numTreeArgs + numEither, numOpArgs);
PrintFatalError(&def, err);
}
@@ -791,10 +801,30 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
}
- for (int i = 0; i != numTreeArgs; ++i) {
+ // The operand in `either` DAG should be bound to the operation in the
+ // parent DagNode.
+ auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
+ int &opArgIdx) {
+ for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
+ if (DagNode subTree = tree.getArgAsNestedDag(i)) {
+ collectBoundSymbols(subTree, infoMap, isSrcPattern);
+ } else {
+ auto argName = tree.getArgName(i);
+ if (!argName.empty() && argName != "_")
+ verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
+ argName);
+ }
+ }
+ };
+
+ for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
- // This DAG node argument is a DAG node itself. Go inside recursively.
- collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+ if (treeArg.isEither()) {
+ collectSymbolInEither(tree, treeArg, opArgIdx);
+ } else {
+ // This DAG node argument is a DAG node itself. Go inside recursively.
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+ }
continue;
}
@@ -806,7 +836,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
- verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
+ verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
treeArgName);
}
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c2b408244d08..098937d1d74c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1328,6 +1328,30 @@ def : Pat<(OneI32ResultOp),
(replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>;
+//===----------------------------------------------------------------------===//
+// Test Patterns (either)
+
+def TestEitherOpA : TEST_Op<"either_op_a"> {
+ let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2);
+ let results = (outs I32:$output);
+}
+
+def TestEitherOpB : TEST_Op<"either_op_b"> {
+ let arguments = (ins AnyInteger:$arg0);
+ let results = (outs I32:$output);
+}
+
+def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
+ (TestEitherOpB $arg2)>;
+
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
+ (TestEitherOpB $arg2)>;
+
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
+ (TestEitherOpB I16:$arg2)),
+ $_),
+ (TestEitherOpB $arg2)>;
+
//===----------------------------------------------------------------------===//
// Test Patterns (Location)
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index ab436fa33bd9..65bf11d8d14c 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -531,6 +531,40 @@ func @redundantTest(%arg0: i32) -> i32 {
return %0 : i32
}
+//===----------------------------------------------------------------------===//
+// Test either directive
+//===----------------------------------------------------------------------===//
+
+// CHECK: @either_dag_leaf_only
+func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
+ return
+}
+
+// CHECK: @either_dag_leaf_dag_node
+func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+ %0 = "test.either_op_b"(%arg0) : (i32) -> i32
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
+ return
+}
+
+// CHECK: @either_dag_node_dag_node
+func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+ %0 = "test.either_op_b"(%arg0) : (i32) -> i32
+ %1 = "test.either_op_b"(%arg1) : (i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
+ // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test that ops without type deduction can be created with type builders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 318f006cdf45..886fd1d07ac4 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -117,10 +117,17 @@ class PatternEmitter {
void emitOpMatch(DagNode tree, StringRef opName, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
- // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
- // the preceding attributes.
- void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
- int operandIndex, int depth);
+ // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
+ // bound name and the constraint of the operand respectively.
+ void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
+ DagLeaf operandMatcher, StringRef argName,
+ int argIndex);
+
+ // Emits C++ statements for matching the operands which can be matched in
+ // either order.
+ void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
+ StringRef opName, int argIndex, int &operandIndex,
+ int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
@@ -470,6 +477,9 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = formatv("arg{0}_{1}", depth, i);
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ if (argTree.isEither())
+ PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
+
os << "Value " << argName << ";\n";
} else {
auto leaf = tree.getArgAsLeaf(i);
@@ -584,12 +594,6 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
formatv("\"{0} is not {1} type\"", castedName,
op.getQualCppClassName()));
- if (tree.getNumArgs() != op.getNumArgs())
- PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
- "pattern vs. {2} in definition",
- op.getOperationName(), tree.getNumArgs(),
- op.getNumArgs()));
-
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
@@ -601,6 +605,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ if (argTree.isEither()) {
+ emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
+ depth);
+ continue;
+ }
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
@@ -609,6 +618,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
PrintFatalError(loc, error);
}
}
+
os << "{\n";
// Attributes don't count for getODSOperands.
@@ -618,9 +628,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
argName, castedName, nextOperand);
// Null check of operand's definingOp
- emitMatchCheck(castedName, /*matchStr=*/argName,
- formatv("\"Operand {0} of {1} has null definingOp\"",
- nextOperand++, castedName));
+ emitMatchCheck(
+ castedName, /*matchStr=*/argName,
+ formatv("\"There's no operation that defines operand {0} of {1}\"",
+ nextOperand++, castedName));
emitMatch(argTree, argName, depth + 1);
os << formatv("tblgen_ops.push_back({0});\n", argName);
os.unindent() << "}\n";
@@ -629,8 +640,12 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
- // emitOperandMatch's argument indexing counts attributes.
- emitOperandMatch(tree, castedName, i, nextOperand, depth);
+ auto operandName =
+ formatv("{0}.getODSOperands({1})", castedName, nextOperand);
+ emitOperandMatch(tree, castedName, operandName.str(),
+ /*operandMatcher=*/tree.getArgAsLeaf(i),
+ /*argName=*/tree.getArgName(i),
+ /*argIndex=*/i);
++nextOperand;
} else if (opArg.is<NamedAttribute *>()) {
emitAttributeMatch(tree, opName, i, depth);
@@ -644,24 +659,23 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
}
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
- int argIndex, int operandIndex,
- int depth) {
+ StringRef operandName,
+ DagLeaf operandMatcher, StringRef argName,
+ int argIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
- auto matcher = tree.getArgAsLeaf(argIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
- if (!matcher.isUnspecified()) {
- if (!matcher.isOperandMatcher()) {
+ if (!operandMatcher.isUnspecified()) {
+ if (!operandMatcher.isOperandMatcher())
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
op.getOperationName(), argIndex + 1));
- }
// Only need to verify if the matcher's type is different from the one
// of op definition.
- Constraint constraint = matcher.getAsConstraint();
+ Constraint constraint = operandMatcher.getAsConstraint();
if (operand->constraint != constraint) {
if (operand->isVariableLength()) {
auto error = formatv(
@@ -669,36 +683,93 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
op.getOperationName(), argIndex);
PrintFatalError(loc, error);
}
- auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
- opName, operandIndex);
+ auto self = formatv("(*{0}.begin()).getType()", operandName);
StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
emitStaticVerifierCall(
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
- operandIndex, op.getOperationName(),
+ operand - op.operand_begin(), op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
}
// Capture the value
- auto name = tree.getArgName(argIndex);
// `$_` is a special symbol to ignore op argument matching.
- if (!name.empty() && name != "_") {
- // We need to subtract the number of attributes before this operand to get
- // the index in the operand list.
- auto numPrevAttrs = std::count_if(
- op.arg_begin(), op.arg_begin() + argIndex,
- [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
-
- auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
- os << formatv("{0} = {1}.getODSOperands({2});\n",
- res->second.getVarName(name), opName,
- argIndex - numPrevAttrs);
+ if (!argName.empty() && argName != "_") {
+ auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
+ os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
}
}
+void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
+ StringRef opName, int argIndex,
+ int &operandIndex, int depth) {
+ constexpr int numEitherArgs = 2;
+ if (eitherArgTree.getNumArgs() != numEitherArgs)
+ PrintFatalError(loc, "`either` only supports grouping two operands");
+
+ Operator &op = tree.getDialectOp(opMap);
+
+ std::string codeBuffer;
+ llvm::raw_string_ostream tblgenOps(codeBuffer);
+
+ std::string lambda = formatv("eitherLambda{0}", depth);
+ os << formatv("auto {0} = [&](OperandRange v0, OperandRange v1) {{\n",
+ lambda);
+
+ os.indent();
+
+ for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
+ if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
+ if (argTree.isEither())
+ PrintFatalError(loc, "either cannot be nested");
+
+ std::string argName = formatv("local_op_{0}", i).str();
+
+ os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
+ i);
+ emitMatchCheck(
+ opName, /*matchStr=*/argName,
+ formatv("\"There's no operation that defines operand {0} of {1}\"",
+ operandIndex++, opName));
+ emitMatch(argTree, argName, depth + 1);
+ // `tblgen_ops` is used to collect the matched operations. In either, we
+ // need to queue the operation only if the matching succeeds. Thus we emit
+ // the code at the end.
+ tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
+ } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+ emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
+ /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
+ /*argName=*/eitherArgTree.getArgName(i), argIndex);
+ ++operandIndex;
+ } else {
+ PrintFatalError(loc, "either can only be applied on operand");
+ }
+ }
+
+ os << tblgenOps.str();
+ os << "return success();\n";
+ os.unindent() << "};\n";
+
+ os << "{\n";
+ os.indent();
+
+ os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
+ operandIndex - 2);
+ os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
+ operandIndex - 1);
+
+ os << formatv("if(failed({0}(eitherOperand0, eitherOperand1)) && "
+ "failed({0}(eitherOperand1, "
+ "eitherOperand0)))\n",
+ lambda);
+ os.indent() << "return failure();\n";
+
+ os.unindent().unindent() << "}\n";
+}
+
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
More information about the Mlir-commits
mailing list