[Mlir-commits] [mlir] 2d99c81 - [mlir-tblgen] Support `either` in Tablegen DRR.

Chia-hung Duan llvmlistbot at llvm.org
Mon Nov 8 15:17:11 PST 2021


Author: Chia-hung Duan
Date: 2021-11-08T23:16:03Z
New Revision: 2d99c815d7c2f40d9be1270b276768374291b68e

URL: https://github.com/llvm/llvm-project/commit/2d99c815d7c2f40d9be1270b276768374291b68e
DIFF: https://github.com/llvm/llvm-project/commit/2d99c815d7c2f40d9be1270b276768374291b68e.diff

LOG: [mlir-tblgen] Support `either` in Tablegen DRR.

Add a new directive `either` to specify the operands can be matched in either order

Reviewed By: jpienaar, Mogball

Differential Revision: https://reviews.llvm.org/D110666

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 94c0100d4197..7a9a001ca264 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -774,6 +774,23 @@ Explicitly-specified return types will take precedence over return types
 inferred from op traits or user-defined builders. The return types of values
 replacing root op results cannot be overridden.
 
+### `either`
+
+The `either` directive is used to specify that the operands may be matched in
+either order.
+
+```tablegen
+def : Pat<(TwoArgOp (either $firstArg, (AnOp $secondArg))),
+          (...)>;
+```
+
+The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` or
+`"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
+
+Only operands are supported with `either`. Note that an operation with the
+`Commutative` trait doesn't imply that it'll have the same behavior as
+`either` while pattern matching.
+
 ## Debugging Tips
 
 ### Run `mlir-tblgen` to see the generated content

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 37bf1d233c2b..7ba93ad2476d 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2730,6 +2730,21 @@ def location;
 
 def returnType;
 
+// Directive used to specify that the operands may be matched in either order.
+// When two adjacent operands are marked with `either`, it'll try to match the
+// operands in either ordering of constraints. Example:
+//
+// ```
+// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
+// ```
+// The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` or
+// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
+//
+// Only operands are supported with `either`. Note that an operation with the
+// `Commutative` trait doesn't imply that it'll have the same behavior as
+// `either` while pattern matching.
+def either;
+
 //===----------------------------------------------------------------------===//
 // Attribute and Type generation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 834ebdace12d..ac8c98e5a490 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -186,6 +186,9 @@ class DagNode {
   // Returns true if this DAG node is wrapping native code call.
   bool isNativeCodeCall() const;
 
+  // Returns whether this DAG is an `either` specifier.
+  bool isEither() const;
+
   // Returns true if this DAG node is an operation.
   bool isOperation() const;
 

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 084345238785..148ca49e65e0 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -113,7 +113,7 @@ bool DagNode::isNativeCodeCall() const {
 
 bool DagNode::isOperation() const {
   return !isNativeCodeCall() && !isReplaceWithValue() &&
-         !isLocationDirective() && !isReturnTypeDirective();
+         !isLocationDirective() && !isReturnTypeDirective() && !isEither();
 }
 
 llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -142,7 +142,9 @@ Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
 }
 
 int DagNode::getNumOps() const {
-  int count = isReplaceWithValue() ? 0 : 1;
+  // We want to get number of operations recursively involved in the DAG tree.
+  // All other directives should be excluded.
+  int count = isOperation() ? 1 : 0;
   for (int i = 0, e = getNumArgs(); i != e; ++i) {
     if (auto child = getArgAsNestedDag(i))
       count += child.getNumOps();
@@ -184,6 +186,11 @@ bool DagNode::isReturnTypeDirective() const {
   return dagOpDef->getName() == "returnType";
 }
 
+bool DagNode::isEither() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "either";
+}
+
 void DagNode::print(raw_ostream &os) const {
   if (node)
     node->print(os);
@@ -764,22 +771,25 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
   if (tree.isOperation()) {
     auto &op = getDialectOp(tree);
     auto numOpArgs = op.getNumArgs();
+    int numEither = 0;
 
-    // The pattern might have trailing directives.
+    // We need to exclude the trailing directives and `either` directive groups
+    // two operands of the operation.
     int numDirectives = 0;
     for (int i = numTreeArgs - 1; i >= 0; --i) {
       if (auto dagArg = tree.getArgAsNestedDag(i)) {
         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
           ++numDirectives;
-        else
-          break;
+        else if (dagArg.isEither())
+          ++numEither;
       }
     }
 
-    if (numOpArgs != numTreeArgs - numDirectives) {
-      auto err = formatv("op '{0}' argument number mismatch: "
-                         "{1} in pattern vs. {2} in definition",
-                         op.getOperationName(), numTreeArgs, numOpArgs);
+    if (numOpArgs != numTreeArgs - numDirectives + numEither) {
+      auto err =
+          formatv("op '{0}' argument number mismatch: "
+                  "{1} in pattern vs. {2} in definition",
+                  op.getOperationName(), numTreeArgs + numEither, numOpArgs);
       PrintFatalError(&def, err);
     }
 
@@ -791,10 +801,30 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
     }
 
-    for (int i = 0; i != numTreeArgs; ++i) {
+    // The operand in `either` DAG should be bound to the operation in the
+    // parent DagNode.
+    auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
+                                     int &opArgIdx) {
+      for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
+        if (DagNode subTree = tree.getArgAsNestedDag(i)) {
+          collectBoundSymbols(subTree, infoMap, isSrcPattern);
+        } else {
+          auto argName = tree.getArgName(i);
+          if (!argName.empty() && argName != "_")
+            verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
+                       argName);
+        }
+      }
+    };
+
+    for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
       if (auto treeArg = tree.getArgAsNestedDag(i)) {
-        // This DAG node argument is a DAG node itself. Go inside recursively.
-        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        if (treeArg.isEither()) {
+          collectSymbolInEither(tree, treeArg, opArgIdx);
+        } else {
+          // This DAG node argument is a DAG node itself. Go inside recursively.
+          collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        }
         continue;
       }
 
@@ -806,7 +836,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
         if (!treeArgName.empty() && treeArgName != "_") {
           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                   << treeArgName << '\n');
-          verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
+          verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
                      treeArgName);
         }
       }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c2b408244d08..098937d1d74c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1328,6 +1328,30 @@ def : Pat<(OneI32ResultOp),
               (replaceWithValue $results__2),
               ConstantAttr<I32Attr, "2">)>;
 
+//===----------------------------------------------------------------------===//
+// Test Patterns (either)
+
+def TestEitherOpA : TEST_Op<"either_op_a"> {
+  let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2);
+  let results = (outs I32:$output);
+}
+
+def TestEitherOpB : TEST_Op<"either_op_b"> {
+  let arguments = (ins AnyInteger:$arg0);
+  let results = (outs I32:$output);
+}
+
+def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
+          (TestEitherOpB $arg2)>;
+
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
+          (TestEitherOpB $arg2)>;
+
+def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
+                                 (TestEitherOpB I16:$arg2)),
+                          $_),
+          (TestEitherOpB $arg2)>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Location)
 

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index ab436fa33bd9..65bf11d8d14c 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -531,6 +531,40 @@ func @redundantTest(%arg0: i32) -> i32 {
   return %0 : i32
 }
 
+//===----------------------------------------------------------------------===//
+// Test either directive
+//===----------------------------------------------------------------------===//
+
+// CHECK: @either_dag_leaf_only
+func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
+  return
+}
+
+// CHECK: @either_dag_leaf_dag_node
+func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+  %0 = "test.either_op_b"(%arg0) : (i32) -> i32
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
+  return
+}
+
+// CHECK: @either_dag_node_dag_node
+func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+  %0 = "test.either_op_b"(%arg0) : (i32) -> i32
+  %1 = "test.either_op_b"(%arg1) : (i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
+  // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
+  %3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test that ops without type deduction can be created with type builders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 318f006cdf45..886fd1d07ac4 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -117,10 +117,17 @@ class PatternEmitter {
   void emitOpMatch(DagNode tree, StringRef opName, int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
-  // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
-  // the preceding attributes.
-  void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
-                        int operandIndex, int depth);
+  // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
+  // bound name and the constraint of the operand respectively.
+  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
+                        DagLeaf operandMatcher, StringRef argName,
+                        int argIndex);
+
+  // Emits C++ statements for matching the operands which can be matched in
+  // either order.
+  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
+                              StringRef opName, int argIndex, int &operandIndex,
+                              int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
@@ -470,6 +477,9 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     std::string argName = formatv("arg{0}_{1}", depth, i);
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+      if (argTree.isEither())
+        PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
+
       os << "Value " << argName << ";\n";
     } else {
       auto leaf = tree.getArgAsLeaf(i);
@@ -584,12 +594,6 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                    formatv("\"{0} is not {1} type\"", castedName,
                            op.getQualCppClassName()));
 
-  if (tree.getNumArgs() != op.getNumArgs())
-    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
-                                 "pattern vs. {2} in definition",
-                                 op.getOperationName(), tree.getNumArgs(),
-                                 op.getNumArgs()));
-
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
   if (!name.empty())
@@ -601,6 +605,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+      if (argTree.isEither()) {
+        emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
+                               depth);
+        continue;
+      }
       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
         if (operand->isVariableLength()) {
           auto error = formatv("use nested DAG construct to match op {0}'s "
@@ -609,6 +618,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
           PrintFatalError(loc, error);
         }
       }
+
       os << "{\n";
 
       // Attributes don't count for getODSOperands.
@@ -618,9 +628,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
           argName, castedName, nextOperand);
       // Null check of operand's definingOp
-      emitMatchCheck(castedName, /*matchStr=*/argName,
-                     formatv("\"Operand {0} of {1} has null definingOp\"",
-                             nextOperand++, castedName));
+      emitMatchCheck(
+          castedName, /*matchStr=*/argName,
+          formatv("\"There's no operation that defines operand {0} of {1}\"",
+                  nextOperand++, castedName));
       emitMatch(argTree, argName, depth + 1);
       os << formatv("tblgen_ops.push_back({0});\n", argName);
       os.unindent() << "}\n";
@@ -629,8 +640,12 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
-      // emitOperandMatch's argument indexing counts attributes.
-      emitOperandMatch(tree, castedName, i, nextOperand, depth);
+      auto operandName =
+          formatv("{0}.getODSOperands({1})", castedName, nextOperand);
+      emitOperandMatch(tree, castedName, operandName.str(),
+                       /*operandMatcher=*/tree.getArgAsLeaf(i),
+                       /*argName=*/tree.getArgName(i),
+                       /*argIndex=*/i);
       ++nextOperand;
     } else if (opArg.is<NamedAttribute *>()) {
       emitAttributeMatch(tree, opName, i, depth);
@@ -644,24 +659,23 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
 }
 
 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
-                                      int argIndex, int operandIndex,
-                                      int depth) {
+                                      StringRef operandName,
+                                      DagLeaf operandMatcher, StringRef argName,
+                                      int argIndex) {
   Operator &op = tree.getDialectOp(opMap);
   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
-  auto matcher = tree.getArgAsLeaf(argIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
-  if (!matcher.isUnspecified()) {
-    if (!matcher.isOperandMatcher()) {
+  if (!operandMatcher.isUnspecified()) {
+    if (!operandMatcher.isOperandMatcher())
       PrintFatalError(
           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                        op.getOperationName(), argIndex + 1));
-    }
 
     // Only need to verify if the matcher's type is different from the one
     // of op definition.
-    Constraint constraint = matcher.getAsConstraint();
+    Constraint constraint = operandMatcher.getAsConstraint();
     if (operand->constraint != constraint) {
       if (operand->isVariableLength()) {
         auto error = formatv(
@@ -669,36 +683,93 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
             op.getOperationName(), argIndex);
         PrintFatalError(loc, error);
       }
-      auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
-                          opName, operandIndex);
+      auto self = formatv("(*{0}.begin()).getType()", operandName);
       StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
       emitStaticVerifierCall(
           verifier, opName, self.str(),
           formatv(
               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
-              operandIndex, op.getOperationName(),
+              operand - op.operand_begin(), op.getOperationName(),
               escapeString(constraint.getSummary()))
               .str());
     }
   }
 
   // Capture the value
-  auto name = tree.getArgName(argIndex);
   // `$_` is a special symbol to ignore op argument matching.
-  if (!name.empty() && name != "_") {
-    // We need to subtract the number of attributes before this operand to get
-    // the index in the operand list.
-    auto numPrevAttrs = std::count_if(
-        op.arg_begin(), op.arg_begin() + argIndex,
-        [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
-
-    auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
-    os << formatv("{0} = {1}.getODSOperands({2});\n",
-                  res->second.getVarName(name), opName,
-                  argIndex - numPrevAttrs);
+  if (!argName.empty() && argName != "_") {
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
+    os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
   }
 }
 
+void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
+                                            StringRef opName, int argIndex,
+                                            int &operandIndex, int depth) {
+  constexpr int numEitherArgs = 2;
+  if (eitherArgTree.getNumArgs() != numEitherArgs)
+    PrintFatalError(loc, "`either` only supports grouping two operands");
+
+  Operator &op = tree.getDialectOp(opMap);
+
+  std::string codeBuffer;
+  llvm::raw_string_ostream tblgenOps(codeBuffer);
+
+  std::string lambda = formatv("eitherLambda{0}", depth);
+  os << formatv("auto {0} = [&](OperandRange v0, OperandRange v1) {{\n",
+                lambda);
+
+  os.indent();
+
+  for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
+    if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
+      if (argTree.isEither())
+        PrintFatalError(loc, "either cannot be nested");
+
+      std::string argName = formatv("local_op_{0}", i).str();
+
+      os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
+                    i);
+      emitMatchCheck(
+          opName, /*matchStr=*/argName,
+          formatv("\"There's no operation that defines operand {0} of {1}\"",
+                  operandIndex++, opName));
+      emitMatch(argTree, argName, depth + 1);
+      // `tblgen_ops` is used to collect the matched operations. In either, we
+      // need to queue the operation only if the matching success. Thus we emit
+      // the code at the end.
+      tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
+    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
+                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
+                       /*argName=*/eitherArgTree.getArgName(i), argIndex);
+      ++operandIndex;
+    } else {
+      PrintFatalError(loc, "either can only be applied on operand");
+    }
+  }
+
+  os << tblgenOps.str();
+  os << "return success();\n";
+  os.unindent() << "};\n";
+
+  os << "{\n";
+  os.indent();
+
+  os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
+                operandIndex - 2);
+  os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
+                operandIndex - 1);
+
+  os << formatv("if(failed({0}(eitherOperand0, eitherOperand1)) && "
+                "failed({0}(eitherOperand1, "
+                "eitherOperand0)))\n",
+                lambda);
+  os.indent() << "return failure();\n";
+
+  os.unindent().unindent() << "}\n";
+}
+
 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);


        


More information about the Mlir-commits mailing list