[Mlir-commits] [mlir] 7271c1b - [DDR] Introduce implicit equality check for the source pattern operands with the same name.

Jacques Pienaar llvmlistbot at llvm.org
Tue Oct 13 16:17:32 PDT 2020


Author: rdzhabarov
Date: 2020-10-13T16:05:14-07:00
New Revision: 7271c1bcb96051bcd227d3fa6071a620fe238850

URL: https://github.com/llvm/llvm-project/commit/7271c1bcb96051bcd227d3fa6071a620fe238850
DIFF: https://github.com/llvm/llvm-project/commit/7271c1bcb96051bcd227d3fa6071a620fe238850.diff

LOG: [DDR] Introduce implicit equality check for the source pattern operands with the same name.

This CL allows users to give the same name to multiple operands in the source pattern, which implicitly enforces equality among all operands that share that name.
E.g., Pat<(OpA $a, $b, $a) ... > creates a matching rule that checks the first and the last operands for equality. Equality is enforced at any nesting depth, e.g., (OpA $a, $b, (OpB $a, $c, (OpC $a))).

Example usage: Pat<(Reshape $arg0, (Shape $arg0)), (replaceWithValue $arg0)>

Note: this feature covers only operands, not attributes.
Existing use cases need operand equality and currently have to add the constraint to the pattern explicitly. Attribute equality will be handled in a separate CL.

Differential Revision: https://reviews.llvm.org/D89254

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index a5759e358f69..4fc2ae762a66 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -21,6 +21,8 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
+#include <unordered_map>
+
 namespace llvm {
 class DagInit;
 class Init;
@@ -228,6 +230,9 @@ class SymbolInfoMap {
     // value bound by this symbol.
     std::string getVarDecl(StringRef name) const;
 
+    // Returns a variable name for the symbol named as `name`.
+    std::string getVarName(StringRef name) const;
+
   private:
     // Allow SymbolInfoMap to access private methods.
     friend class SymbolInfoMap;
@@ -285,9 +290,12 @@ class SymbolInfoMap {
     Kind kind;          // The kind of the bound entity
     // The argument index (for `Attr` and `Operand` only)
     Optional<int> argIndex;
+    // Alternative name for the symbol. It is used in case the name
+    // is not unique. Applicable for `Operand` only.
+    Optional<std::string> alternativeName;
   };
 
-  using BaseT = llvm::StringMap<SymbolInfo>;
+  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
 
   // Iterators for accessing all symbols.
   using iterator = BaseT::iterator;
@@ -300,7 +308,7 @@ class SymbolInfoMap {
   const_iterator end() const { return symbolInfoMap.end(); }
 
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
-  // Returns false if `symbol` is already bound.
+  // Returns false if `symbol` is already bound and symbols are not operands.
   bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
@@ -317,6 +325,18 @@ class SymbolInfoMap {
   // Returns an iterator to the information of the given symbol named as `key`.
   const_iterator find(StringRef key) const;
 
+  // Returns an iterator to the information of the given symbol named as `key`,
+  // with index `argIndex` for operator `op`.
+  const_iterator findBoundSymbol(StringRef key, const Operator &op,
+                                 int argIndex) const;
+
+  // Returns the bounds of a range that includes all the elements which
+  // bind to the `key`.
+  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
+
+  // Returns number of times symbol named as `key` was used.
+  int count(StringRef key) const;
+
   // Returns the number of static values of the given `symbol` corresponds to.
   // A static value is an operand/result declared in ODS. Normally a symbol only
   // represents one static value, but symbols bound to op results can represent
@@ -338,6 +358,9 @@ class SymbolInfoMap {
   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
                              const char *separator = ", ") const;
 
+  // Assign alternative unique names to Operands that have equal names.
+  void assignUniqueAlternativeNames();
+
   // Splits the given `symbol` into a value pack name and an index. Returns the
   // value pack name and writes the index to `index` on success. Returns
   // `symbol` itself if it does not contain an index.
@@ -347,7 +370,7 @@ class SymbolInfoMap {
   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
 
 private:
-  llvm::StringMap<SymbolInfo> symbolInfoMap;
+  BaseT symbolInfoMap;
 
   // Pattern instantiation location. This is intended to be used as parameter
   // to PrintFatalError() to report errors.

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index cfa3da2c417a..5170f07870ab 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -208,6 +208,10 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
   llvm_unreachable("unknown kind");
 }
 
+std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
+  return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
+}
+
 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
@@ -219,8 +223,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
     // operands).
-    return std::string(formatv(
-        "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
+    return std::string(
+        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
+                getVarName(name)));
   }
   case Kind::Value: {
     return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@@ -359,16 +364,34 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
                      ? SymbolInfo::getAttr(&op, argIndex)
                      : SymbolInfo::getOperand(&op, argIndex);
 
-  return symbolInfoMap.insert({symbol, symInfo}).second;
+  std::string key = symbol.str();
+  if (auto numberOfEntries = symbolInfoMap.count(key)) {
+    // Only non unique name for the operand is supported.
+    if (symInfo.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+
+    // Cannot add new operand if there is already non operand with the same
+    // name.
+    if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+  }
+
+  symbolInfoMap.emplace(key, symInfo);
+  return true;
 }
 
 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
   StringRef name = getValuePackName(symbol);
-  return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+
+  return symbolInfoMap.count(inserted->first) == 1;
 }
 
 bool SymbolInfoMap::bindValue(StringRef symbol) {
-  return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+  auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
+  return symbolInfoMap.count(inserted->first) == 1;
 }
 
 bool SymbolInfoMap::contains(StringRef symbol) const {
@@ -376,10 +399,38 @@ bool SymbolInfoMap::contains(StringRef symbol) const {
 }
 
 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
-  StringRef name = getValuePackName(key);
+  std::string name = getValuePackName(key).str();
+
   return symbolInfoMap.find(name);
 }
 
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+                               int argIndex) const {
+  std::string name = getValuePackName(key).str();
+  auto range = symbolInfoMap.equal_range(name);
+
+  for (auto it = range.first; it != range.second; ++it) {
+    if (it->second.op == &op && it->second.argIndex == argIndex) {
+      return it;
+    }
+  }
+
+  return symbolInfoMap.end();
+}
+
+std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
+SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
+  std::string name = getValuePackName(key).str();
+
+  return symbolInfoMap.equal_range(name);
+}
+
+int SymbolInfoMap::count(StringRef key) const {
+  std::string name = getValuePackName(key).str();
+  return symbolInfoMap.count(name);
+}
+
 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
   StringRef name = getValuePackName(symbol);
   if (name != symbol) {
@@ -388,7 +439,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
     return 1;
   }
   // Otherwise, find how many it represents by querying the symbol's info.
-  return find(name)->getValue().getStaticValueCount();
+  return find(name)->second.getStaticValueCount();
 }
 
 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@@ -397,13 +448,13 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name);
+  auto it = symbolInfoMap.find(name.str());
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
+  return it->second.getValueAndRangeUse(name, index, fmt, separator);
 }
 
 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
@@ -411,13 +462,44 @@ std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name);
+  auto it = symbolInfoMap.find(name.str());
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->getValue().getAllRangeUse(name, index, fmt, separator);
+  return it->second.getAllRangeUse(name, index, fmt, separator);
+}
+
+void SymbolInfoMap::assignUniqueAlternativeNames() {
+  llvm::StringSet<> usedNames;
+
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto operandName = symbolInfoIt->first;
+    int startSearchIndex = 0;
+    for (++startRange; startRange != endRange; ++startRange) {
+      // Current operand name is not unique, find a unique one
+      // and set the alternative name.
+      for (int i = startSearchIndex;; ++i) {
+        std::string alternativeName = operandName + std::to_string(i);
+        if (!usedNames.contains(alternativeName) &&
+            symbolInfoMap.count(alternativeName) == 0) {
+          usedNames.insert(alternativeName);
+          startRange->second.alternativeName = alternativeName;
+          startSearchIndex = i + 1;
+
+          break;
+        }
+      }
+    }
+
+    symbolInfoIt = endRange;
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -445,6 +527,10 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
+
+  LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
+  infoMap.assignUniqueAlternativeNames();
+  LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
 }
 
 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d36d7bd58ea8..aef39a9e19fe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -619,6 +619,32 @@ def OpM : TEST_Op<"op_m"> {
   let results = (outs I32);
 }
 
+def OpN : TEST_Op<"op_n"> {
+  let arguments = (ins I32, I32);
+  let results = (outs I32);
+}
+
+def OpO : TEST_Op<"op_o"> {
+  let arguments = (ins I32);
+  let results = (outs I32);
+}
+
+def OpP : TEST_Op<"op_p"> {
+  let arguments = (ins I32, I32, I32, I32, I32, I32);
+  let results = (outs I32);
+}
+
+// Test same operand name enforces equality condition check.
+def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+// Test when equality is enforced at different depths.
+def TestNestedOpEqualArgsPattern :
+  Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+// Test multiple equal arguments check enforced.
+def TestMultipleEqualArgsPattern :
+  Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
 // Test for memrefs normalization of an op with normalizable memrefs.
 def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 0f2f434c928f..5986be6240f9 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -111,6 +111,64 @@ func @verifyManyArgs(%arg: i32) {
   return
 }
 
+// CHECK-LABEL: verifyEqualArgs
+func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
+  // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
+
+  // CHECK: "test.op_o"(%arg0) : (i32) -> i32
+  "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
+
+  // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
+  "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
+
+  return
+}
+
+// CHECK-LABEL: verifyNestedOpEqualArgs
+func @verifyNestedOpEqualArgs(
+  %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
+  // def TestNestedOpEqualArgsPattern :
+  //   Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+
+  // CHECK: %arg1
+  %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+    : (i32, i32, i32, i32, i32, i32) -> (i32)
+  %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+
+  // CHECK: test.op_p
+  // CHECK: test.op_n
+  %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+    : (i32, i32, i32, i32, i32, i32) -> (i32)
+  %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+
+  return
+}
+
+// CHECK-LABEL: verifyMultipleEqualArgs
+func @verifyMultipleEqualArgs(
+  %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
+  // def TestMultipleEqualArgsPattern :
+  //   Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+
+  // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
+  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+   // CHECK: test.op_p
+  "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
+    (i32, i32, i32, i32 , i32, i32) -> i32
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index ff6138f73914..7bff3e3b40b6 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -89,6 +89,11 @@ class PatternEmitter {
   void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
                       const llvm::formatv_object_base &failureFmt);
 
+  // Emits C++ for checking a match with a corresponding match failure
+  // diagnostics.
+  void emitMatchCheck(int depth, const std::string &matchStr,
+                      const std::string &failureStr);
+
   //===--------------------------------------------------------------------===//
   // Rewrite utilities
   //===--------------------------------------------------------------------===//
@@ -327,8 +332,9 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
-                  argIndex - numPrevAttrs);
+    auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
+    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
+                  res->second.getVarName(name), depth, argIndex - numPrevAttrs);
   }
 }
 
@@ -393,10 +399,15 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
 void PatternEmitter::emitMatchCheck(
     int depth, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  os << "if (!(" << matchFmt.str() << "))";
+  emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
+}
+
+void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
+                                    const std::string &failureStr) {
+  os << "if (!(" << matchStr << "))";
   os.scope("{\n", "\n}\n").os
       << "return rewriter.notifyMatchFailure(op" << depth
-      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
+      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureStr
       << ";\n});";
 }
 
@@ -445,6 +456,30 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
                              constraint.getDescription()));
     }
   }
+
+  // Some of the operands could be bound to the same symbol name, we need
+  // to enforce equality constraint on those.
+  // TODO: we should be able to emit equality checks early
+  // and short circuit unnecessary work if vars are not equal.
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+    for (++startRange; startRange != endRange; ++startRange) {
+      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+      emitMatchCheck(
+          depth,
+          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+                  secondOperand));
+    }
+
+    symbolInfoIt = endRange;
+  }
+
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
 }
 
@@ -518,8 +553,9 @@ void PatternEmitter::emit(StringRef rewriteName) {
       // Create local variables for storing the arguments and results bound
       // to symbols.
       for (const auto &symbolInfoPair : symbolInfoMap) {
-        StringRef symbol = symbolInfoPair.getKey();
-        auto &info = symbolInfoPair.getValue();
+        const auto &symbol = symbolInfoPair.first;
+        const auto &info = symbolInfoPair.second;
+
         os << info.getVarDecl(symbol);
       }
       // TODO: capture ops with consistent numbering so that it can be
@@ -1093,7 +1129,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                     range);
     } else {
-      os << formatv("tblgen_values.push_back(", varName);
+      os << formatv("tblgen_values.push_back(");
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(
             childNodeNames.lookup(argIndex));


        


More information about the Mlir-commits mailing list