[Mlir-commits] [mlir] 32032cb - [mlir][tblgen] Fix emitting wrong index for `either` directive.

Chia-hung Duan llvmlistbot at llvm.org
Wed May 3 06:11:19 PDT 2023


Author: Chia-hung Duan
Date: 2023-05-03T13:07:45Z
New Revision: 32032cbf25748fcfd8cac9bf6cb0d153dfe151a2

URL: https://github.com/llvm/llvm-project/commit/32032cbf25748fcfd8cac9bf6cb0d153dfe151a2
DIFF: https://github.com/llvm/llvm-project/commit/32032cbf25748fcfd8cac9bf6cb0d153dfe151a2.diff

LOG: [mlir][tblgen] Fix emitting wrong index for `either` directive.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D149152

Added: 
    

Modified: 
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1faea137a301b..e8625b2e6b710 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -807,15 +807,16 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
     // The operand in `either` DAG should be bound to the operation in the
     // parent DagNode.
     auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
-                                     int &opArgIdx) {
+                                     int opArgIdx) {
       for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
         if (DagNode subTree = tree.getArgAsNestedDag(i)) {
           collectBoundSymbols(subTree, infoMap, isSrcPattern);
         } else {
           auto argName = tree.getArgName(i);
-          if (!argName.empty() && argName != "_")
+          if (!argName.empty() && argName != "_") {
             verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
                        argName);
+          }
         }
       }
     };
@@ -824,6 +825,14 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       if (auto treeArg = tree.getArgAsNestedDag(i)) {
         if (treeArg.isEither()) {
           collectSymbolInEither(tree, treeArg, opArgIdx);
+          // `either` DAG is *flattened*. For example,
+          //
+          //  (FooOp (either arg0, arg1), arg2)
+          //
+          //  can be viewed as:
+          //
+          //  (FooOp arg0, arg1, arg2)
+          ++opArgIdx;
         } else {
           // This DAG node argument is a DAG node itself. Go inside recursively.
           collectBoundSymbols(treeArg, infoMap, isSrcPattern);

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 29017581b1f68..60faf6dfe0e89 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1691,20 +1691,20 @@ def TestEitherOpA : TEST_Op<"either_op_a"> {
 }
 
 def TestEitherOpB : TEST_Op<"either_op_b"> {
-  let arguments = (ins AnyInteger:$arg0);
+  let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1);
   let results = (outs I32:$output);
 }
 
-def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
-          (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
+          (TestEitherOpB $arg2, $x)>;
 
-def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
-          (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), I16:$arg2), $x),
+          (TestEitherOpB $arg2, $x)>;
 
-def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
-                                 (TestEitherOpB I16:$arg2)),
-                          $_),
-          (TestEitherOpB $arg2)>;
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
+                                 (TestEitherOpB I16:$arg2, $_)),
+                          $x),
+          (TestEitherOpB $arg2, $x)>;
 
 //===----------------------------------------------------------------------===//
 // Test Patterns (Location)

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index d20ffbe46caaa..4c1182fa9eb72 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -532,30 +532,30 @@ func.func @redundantTest(%arg0: i32) -> i32 {
 
 // CHECK: @either_dag_leaf_only
 func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
   return
 }
 
 // CHECK: @either_dag_leaf_dag_node
 func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
-  %0 = "test.either_op_b"(%arg0) : (i32) -> i32
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
   return
 }
 
 // CHECK: @either_dag_node_dag_node
 func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
-  %0 = "test.either_op_b"(%arg0) : (i32) -> i32
-  %1 = "test.either_op_b"(%arg1) : (i16) -> i32
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
+  %1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
-  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
   return
 }

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index abbe8400b5fe0..8a04cc9f5be11 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -582,22 +582,24 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
   if (!name.empty())
     os << formatv("{0} = {1};\n", name, castedName);
 
-  for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
-    auto opArg = op.getArg(i);
+  for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
+       ++i, ++opArgIdx) {
+    auto opArg = op.getArg(opArgIdx);
     std::string argName = formatv("op{0}", depth + 1);
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
       if (argTree.isEither()) {
-        emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
+        emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
                                depth);
+        ++opArgIdx;
         continue;
       }
       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
         if (operand->isVariableLength()) {
           auto error = formatv("use nested DAG construct to match op {0}'s "
                                "variadic operand #{1} unsupported now",
-                               op.getOperationName(), i);
+                               op.getOperationName(), opArgIdx);
           PrintFatalError(loc, error);
         }
       }
@@ -627,11 +629,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
       emitOperandMatch(tree, castedName, operandName.str(),
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
-                       /*argName=*/tree.getArgName(i),
-                       /*argIndex=*/i);
+                       /*argName=*/tree.getArgName(i), opArgIdx);
       ++nextOperand;
     } else if (opArg.is<NamedAttribute *>()) {
-      emitAttributeMatch(tree, opName, i, depth);
+      emitAttributeMatch(tree, opName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }


        


More information about the Mlir-commits mailing list