[Mlir-commits] [mlir] b4001ae - [mlir-tblgen] Fix failed matching when binds same operand of an op in different depth

Chia-hung Duan llvmlistbot at llvm.org
Tue Jul 20 00:45:25 PDT 2021


Author: Chia-hung Duan
Date: 2021-07-20T15:43:09+08:00
New Revision: b4001ae8851f47406f8187b72b0253b21bf1da4c

URL: https://github.com/llvm/llvm-project/commit/b4001ae8851f47406f8187b72b0253b21bf1da4c
DIFF: https://github.com/llvm/llvm-project/commit/b4001ae8851f47406f8187b72b0253b21bf1da4c.diff

LOG: [mlir-tblgen] Fix failed matching when binds same operand of an op in different depth

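For example, incorrect code was generated for the pattern

def : Pat<(FooOp (FooOp $a, $b), $b), (...)>;

Previously, $b could not be bound to the same operand of the same op at two
different nesting depths: both bindings were distinguished only by the op and
the argument index, so they were indistinguishable from each other.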

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105677

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 98c5d9b18f5dd..4b397c7c02bf6 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -184,6 +184,9 @@ class DagNode {
   void print(raw_ostream &os) const;
 
 private:
+  friend class SymbolInfoMap;
+  const void *getAsOpaquePointer() const { return node; }
+
   const llvm::DagInit *node; // nullptr means null DagNode
 };
 
@@ -237,6 +240,10 @@ class SymbolInfoMap {
     // Allow SymbolInfoMap to access private methods.
     friend class SymbolInfoMap;
 
+    // DagNode and DagLeaf are accessed by value, which means they can't be used
+    // as identifiers here. Use an opaque pointer type instead.
+    using DagAndIndex = std::pair<const void *, int>;
+
     // What kind of entity this symbol represents:
     // * Attr: op attribute
     // * Operand: op operand
@@ -244,19 +251,21 @@ class SymbolInfoMap {
     // * Value: a value not attached to an op (e.g., from NativeCodeCall)
     enum class Kind : uint8_t { Attr, Operand, Result, Value };
 
-    // Creates a SymbolInfo instance. `index` is only used for `Attr` and
-    // `Operand` so should be negative for `Result` and `Value` kind.
-    SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
+    // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
+    // `Operand` so should be llvm::None for `Result` and `Value` kind.
+    SymbolInfo(const Operator *op, Kind kind,
+               Optional<DagAndIndex> dagAndIndex);
 
     // Static methods for creating SymbolInfo.
     static SymbolInfo getAttr(const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Attr, index);
+      return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
     }
     static SymbolInfo getAttr() {
       return SymbolInfo(nullptr, Kind::Attr, llvm::None);
     }
-    static SymbolInfo getOperand(const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Operand, index);
+    static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Operand,
+                        DagAndIndex(node.getAsOpaquePointer(), index));
     }
     static SymbolInfo getResult(const Operator *op) {
       return SymbolInfo(op, Kind::Result, llvm::None);
@@ -291,8 +300,11 @@ class SymbolInfoMap {
 
     const Operator *op; // The op where the bound entity belongs
     Kind kind;          // The kind of the bound entity
-    // The argument index (for `Attr` and `Operand` only)
-    Optional<int> argIndex;
+    // The pair of DagNode pointer and argument index (for `Attr` and `Operand`
+    // only). Note that multiple operands may be bound to the same symbol; the
+    // DagNode and index are used to distinguish them. For `Attr`, the Dag part
+    // will be nullptr.
+    Optional<DagAndIndex> dagAndIndex;
     // Alternative name for the symbol. It is used in case the name
     // is not unique. Applicable for `Operand` only.
     Optional<std::string> alternativeName;
@@ -312,7 +324,8 @@ class SymbolInfoMap {
 
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
   // Returns false if `symbol` is already bound and symbols are not operands.
-  bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
+  bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
+                      int argIndex);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
   // `symbol` is already bound.
@@ -334,8 +347,8 @@ class SymbolInfoMap {
 
   // Returns an iterator to the information of the given symbol named as `key`,
   // with index `argIndex` for operator `op`.
-  const_iterator findBoundSymbol(StringRef key, const Operator &op,
-                                 int argIndex) const;
+  const_iterator findBoundSymbol(StringRef key, DagNode node,
+                                 const Operator &op, int argIndex) const;
 
   // Returns the bounds of a range that includes all the elements which
   // bind to the `key`.

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index d3bd6f7662bff..a9b03519fb540 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -193,8 +193,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
 }
 
 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
-                                      Optional<int> index)
-    : op(op), kind(kind), argIndex(index) {}
+                                      Optional<DagAndIndex> dagAndIndex)
+    : op(op), kind(kind), dagAndIndex(dagAndIndex) {}
 
 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
   switch (kind) {
@@ -217,8 +217,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   switch (kind) {
   case Kind::Attr: {
     if (op) {
-      auto type =
-          op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+      auto type = op->getArg((*dagAndIndex).second)
+                      .get<NamedAttribute *>()
+                      ->attr.getStorageType();
       return std::string(formatv("{0} {1};\n", type, name));
     }
     // TODO(suderman): Use a more exact type when available.
@@ -254,7 +255,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   }
   case Kind::Operand: {
     assert(index < 0);
-    auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
+    auto *operand =
+        op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
     // If this operand is variadic, then return a range. Otherwise, return the
     // value itself.
     if (operand->isVariableLength()) {
@@ -355,8 +357,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
   llvm_unreachable("unknown kind");
 }
 
-bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
-                                   int argIndex) {
+bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
+                                   const Operator &op, int argIndex) {
   StringRef name = getValuePackName(symbol);
   if (name != symbol) {
     auto error = formatv(
@@ -366,7 +368,7 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
 
   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
                      ? SymbolInfo::getAttr(&op, argIndex)
-                     : SymbolInfo::getOperand(&op, argIndex);
+                     : SymbolInfo::getOperand(node, &op, argIndex);
 
   std::string key = symbol.str();
   if (symbolInfoMap.count(key)) {
@@ -414,13 +416,15 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
 }
 
 SymbolInfoMap::const_iterator
-SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
                                int argIndex) const {
   std::string name = getValuePackName(key).str();
   auto range = symbolInfoMap.equal_range(name);
 
+  const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
+
   for (auto it = range.first; it != range.second; ++it) {
-    if (it->second.op == &op && it->second.argIndex == argIndex) {
+    if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
       return it;
     }
   }
@@ -722,7 +726,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
         if (!treeArgName.empty() && treeArgName != "_") {
           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                   << treeArgName << '\n');
-          verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
+          verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
+                     treeArgName);
         }
       }
     }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 16e141ede1736..c94ca64ba6958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -716,6 +716,13 @@ def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
 def TestNestedOpEqualArgsPattern :
   Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
 
+// Test when equality is enforced on the same op and the same operand but at
+// different depths. Previously only one of the $x was bound to the second
+// operand of the outer OpN, and the other was left with the default value
+// (the value of the first operand of the outer OpN). As a result, the wrong
+// values were compared in some cases.
+def TestNestedSameOpAndSameArgEqualityPattern :
+  Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>;
+
 // Test multiple equal arguments check enforced.
 def TestMultipleEqualArgsPattern :
   Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index affc3d7a93968..69140dfb3fd47 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -158,6 +158,17 @@ func @verifyNestedOpEqualArgs(
   return
 }
 
+// CHECK-LABEL: verifyNestedSameOpAndSameArgEquality
+func @verifyNestedSameOpAndSameArgEquality(%arg0: i32, %arg1: i32) -> i32 {
+  // def TestNestedSameOpAndSameArgEqualityPattern:
+  //   Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>;
+
+  %0 = "test.op_n"(%arg1, %arg0) : (i32, i32) -> (i32)
+  %1 = "test.op_n"(%0, %arg0) : (i32, i32) -> (i32)
+  // CHECK: return %arg0 : i32
+  return %1 : i32
+}
+
 // CHECK-LABEL: verifyMultipleEqualArgs
 func @verifyMultipleEqualArgs(
   %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 611bc5c1c05ea..9d3e4a93b53d6 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -454,7 +454,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
+    auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
     os << formatv("{0} = {1}.getODSOperands({2});\n",
                   res->second.getVarName(name), opName,
                   argIndex - numPrevAttrs);


        


More information about the Mlir-commits mailing list