[Mlir-commits] [mlir] 32032cb - [mlir][tblgen] Fix emitting wrong index for `either` directive.
Chia-hung Duan
llvmlistbot at llvm.org
Wed May 3 06:11:19 PDT 2023
Author: Chia-hung Duan
Date: 2023-05-03T13:07:45Z
New Revision: 32032cbf25748fcfd8cac9bf6cb0d153dfe151a2
URL: https://github.com/llvm/llvm-project/commit/32032cbf25748fcfd8cac9bf6cb0d153dfe151a2
DIFF: https://github.com/llvm/llvm-project/commit/32032cbf25748fcfd8cac9bf6cb0d153dfe151a2.diff
LOG: [mlir][tblgen] Fix emitting wrong index for `either` directive.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D149152
Added:
Modified:
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1faea137a301b..e8625b2e6b710 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -807,15 +807,16 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
// The operand in `either` DAG should be bound to the operation in the
// parent DagNode.
auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
- int &opArgIdx) {
+ int opArgIdx) {
for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
if (DagNode subTree = tree.getArgAsNestedDag(i)) {
collectBoundSymbols(subTree, infoMap, isSrcPattern);
} else {
auto argName = tree.getArgName(i);
- if (!argName.empty() && argName != "_")
+ if (!argName.empty() && argName != "_") {
verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
argName);
+ }
}
}
};
@@ -824,6 +825,14 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (auto treeArg = tree.getArgAsNestedDag(i)) {
if (treeArg.isEither()) {
collectSymbolInEither(tree, treeArg, opArgIdx);
+ // `either` DAG is *flattened*. For example,
+ //
+ // (FooOp (either arg0, arg1), arg2)
+ //
+ // can be viewed as:
+ //
+ // (FooOp arg0, arg1, arg2)
+ ++opArgIdx;
} else {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 29017581b1f68..60faf6dfe0e89 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1691,20 +1691,20 @@ def TestEitherOpA : TEST_Op<"either_op_a"> {
}
def TestEitherOpB : TEST_Op<"either_op_b"> {
- let arguments = (ins AnyInteger:$arg0);
+ let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1);
let results = (outs I32:$output);
}
-def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
- (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
+ (TestEitherOpB $arg2, $x)>;
-def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
- (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), I16:$arg2), $x),
+ (TestEitherOpB $arg2, $x)>;
-def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
- (TestEitherOpB I16:$arg2)),
- $_),
- (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
+ (TestEitherOpB I16:$arg2, $_)),
+ $x),
+ (TestEitherOpB $arg2, $x)>;
//===----------------------------------------------------------------------===//
// Test Patterns (Location)
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index d20ffbe46caaa..4c1182fa9eb72 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -532,30 +532,30 @@ func.func @redundantTest(%arg0: i32) -> i32 {
// CHECK: @either_dag_leaf_only
func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
return
}
// CHECK: @either_dag_leaf_dag_node
func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
- %0 = "test.either_op_b"(%arg0) : (i32) -> i32
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
return
}
// CHECK: @either_dag_node_dag_node
func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
- %0 = "test.either_op_b"(%arg0) : (i32) -> i32
- %1 = "test.either_op_b"(%arg1) : (i16) -> i32
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
+ %1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
- // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
return
}
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index abbe8400b5fe0..8a04cc9f5be11 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -582,22 +582,24 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (!name.empty())
os << formatv("{0} = {1};\n", name, castedName);
- for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
- auto opArg = op.getArg(i);
+ for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
+ ++i, ++opArgIdx) {
+ auto opArg = op.getArg(opArgIdx);
std::string argName = formatv("op{0}", depth + 1);
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither()) {
- emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
+ emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
depth);
+ ++opArgIdx;
continue;
}
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
- op.getOperationName(), i);
+ op.getOperationName(), opArgIdx);
PrintFatalError(loc, error);
}
}
@@ -627,11 +629,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(),
/*operandMatcher=*/tree.getArgAsLeaf(i),
- /*argName=*/tree.getArgName(i),
- /*argIndex=*/i);
+ /*argName=*/tree.getArgName(i), opArgIdx);
++nextOperand;
} else if (opArg.is<NamedAttribute *>()) {
- emitAttributeMatch(tree, opName, i, depth);
+ emitAttributeMatch(tree, opName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
More information about the Mlir-commits mailing list