[Mlir-commits] [mlir] 0b793c4 - Revert "[DDR] Introduce implicit equality check for the source pattern operands with the same name."

Mehdi Amini llvmlistbot at llvm.org
Tue Oct 13 17:38:28 PDT 2020


Author: Mehdi Amini
Date: 2020-10-14T00:37:10Z
New Revision: 0b793c4be0eee90a22b7a150187f5f7cf744c120

URL: https://github.com/llvm/llvm-project/commit/0b793c4be0eee90a22b7a150187f5f7cf744c120
DIFF: https://github.com/llvm/llvm-project/commit/0b793c4be0eee90a22b7a150187f5f7cf744c120.diff

LOG: Revert "[DDR] Introduce implicit equality check for the source pattern operands with the same name."

This reverts commit 7271c1bcb96051bcd227d3fa6071a620fe238850.

This broke the gcc-5 build:

/usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair<const std::__cxx11::basic_string<char>, mlir::tblgen::SymbolInfoMap::SymbolInfo>::pair(llvm::StringRef&, mlir::tblgen::SymbolInfoMap::SymbolInfo)'
  { ::new((void *)__p) _Up(std::forward<_Args>(__args)...); }
    ^
In file included from /usr/include/c++/5/utility:70:0,
                 from llvm/include/llvm/Support/type_traits.h:18,
                 from llvm/include/llvm/Support/Casting.h:18,
                 from mlir/include/mlir/Support/LLVM.h:24,
                 from mlir/include/mlir/TableGen/Pattern.h:17,
                 from mlir/lib/TableGen/Pattern.cpp:14:
/usr/include/c++/5/bits/stl_pair.h:206:9: note: candidate: template<class ... _Args1, long unsigned int ..._Indexes1, class ... _Args2, long unsigned int ..._Indexes2> std::pair<_T1, _T2>::pair(std::tuple<_Args1 ...>&, std::tuple<_Args2 ...>&, std::_Index_tuple<_Indexes1 ...>, std::_Index_tuple<_Indexes2 ...>)
         pair(tuple<_Args1...>&, tuple<_Args2...>&,
         ^

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4fc2ae762a667..a5759e358f695 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -21,8 +21,6 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
-#include <unordered_map>
-
 namespace llvm {
 class DagInit;
 class Init;
@@ -230,9 +228,6 @@ class SymbolInfoMap {
     // value bound by this symbol.
     std::string getVarDecl(StringRef name) const;
 
-    // Returns a variable name for the symbol named as `name`.
-    std::string getVarName(StringRef name) const;
-
   private:
     // Allow SymbolInfoMap to access private methods.
     friend class SymbolInfoMap;
@@ -290,12 +285,9 @@ class SymbolInfoMap {
     Kind kind;          // The kind of the bound entity
     // The argument index (for `Attr` and `Operand` only)
     Optional<int> argIndex;
-    // Alternative name for the symbol. It is used in case the name
-    // is not unique. Applicable for `Operand` only.
-    Optional<std::string> alternativeName;
   };
 
-  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
+  using BaseT = llvm::StringMap<SymbolInfo>;
 
   // Iterators for accessing all symbols.
   using iterator = BaseT::iterator;
@@ -308,7 +300,7 @@ class SymbolInfoMap {
   const_iterator end() const { return symbolInfoMap.end(); }
 
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
-  // Returns false if `symbol` is already bound and symbols are not operands.
+  // Returns false if `symbol` is already bound.
   bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
@@ -325,18 +317,6 @@ class SymbolInfoMap {
   // Returns an iterator to the information of the given symbol named as `key`.
   const_iterator find(StringRef key) const;
 
-  // Returns an iterator to the information of the given symbol named as `key`,
-  // with index `argIndex` for operator `op`.
-  const_iterator findBoundSymbol(StringRef key, const Operator &op,
-                                 int argIndex) const;
-
-  // Returns the bounds of a range that includes all the elements which
-  // bind to the `key`.
-  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
-
-  // Returns number of times symbol named as `key` was used.
-  int count(StringRef key) const;
-
   // Returns the number of static values of the given `symbol` corresponds to.
   // A static value is an operand/result declared in ODS. Normally a symbol only
   // represents one static value, but symbols bound to op results can represent
@@ -358,9 +338,6 @@ class SymbolInfoMap {
   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
                              const char *separator = ", ") const;
 
-  // Assign alternative unique names to Operands that have equal names.
-  void assignUniqueAlternativeNames();
-
   // Splits the given `symbol` into a value pack name and an index. Returns the
   // value pack name and writes the index to `index` on success. Returns
   // `symbol` itself if it does not contain an index.
@@ -370,7 +347,7 @@ class SymbolInfoMap {
   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
 
 private:
-  BaseT symbolInfoMap;
+  llvm::StringMap<SymbolInfo> symbolInfoMap;
 
   // Pattern instantiation location. This is intended to be used as parameter
   // to PrintFatalError() to report errors.

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 5170f07870abb..cfa3da2c417a4 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -208,10 +208,6 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
   llvm_unreachable("unknown kind");
 }
 
-std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
-  return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
-}
-
 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
@@ -223,9 +219,8 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
     // operands).
-    return std::string(
-        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
-                getVarName(name)));
+    return std::string(formatv(
+        "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
   }
   case Kind::Value: {
     return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@@ -364,34 +359,16 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
                      ? SymbolInfo::getAttr(&op, argIndex)
                      : SymbolInfo::getOperand(&op, argIndex);
 
-  std::string key = symbol.str();
-  if (auto numberOfEntries = symbolInfoMap.count(key)) {
-    // Only non unique name for the operand is supported.
-    if (symInfo.kind != SymbolInfo::Kind::Operand) {
-      return false;
-    }
-
-    // Cannot add new operand if there is already non operand with the same
-    // name.
-    if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
-      return false;
-    }
-  }
-
-  symbolInfoMap.emplace(key, symInfo);
-  return true;
+  return symbolInfoMap.insert({symbol, symInfo}).second;
 }
 
 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
   StringRef name = getValuePackName(symbol);
-  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
-
-  return symbolInfoMap.count(inserted->first) == 1;
+  return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
 }
 
 bool SymbolInfoMap::bindValue(StringRef symbol) {
-  auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
-  return symbolInfoMap.count(inserted->first) == 1;
+  return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
 }
 
 bool SymbolInfoMap::contains(StringRef symbol) const {
@@ -399,38 +376,10 @@ bool SymbolInfoMap::contains(StringRef symbol) const {
 }
 
 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
-  std::string name = getValuePackName(key).str();
-
+  StringRef name = getValuePackName(key);
   return symbolInfoMap.find(name);
 }
 
-SymbolInfoMap::const_iterator
-SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
-                               int argIndex) const {
-  std::string name = getValuePackName(key).str();
-  auto range = symbolInfoMap.equal_range(name);
-
-  for (auto it = range.first; it != range.second; ++it) {
-    if (it->second.op == &op && it->second.argIndex == argIndex) {
-      return it;
-    }
-  }
-
-  return symbolInfoMap.end();
-}
-
-std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
-SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
-  std::string name = getValuePackName(key).str();
-
-  return symbolInfoMap.equal_range(name);
-}
-
-int SymbolInfoMap::count(StringRef key) const {
-  std::string name = getValuePackName(key).str();
-  return symbolInfoMap.count(name);
-}
-
 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
   StringRef name = getValuePackName(symbol);
   if (name != symbol) {
@@ -439,7 +388,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
     return 1;
   }
   // Otherwise, find how many it represents by querying the symbol's info.
-  return find(name)->second.getStaticValueCount();
+  return find(name)->getValue().getStaticValueCount();
 }
 
 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@@ -448,13 +397,13 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name.str());
+  auto it = symbolInfoMap.find(name);
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->second.getValueAndRangeUse(name, index, fmt, separator);
+  return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
 }
 
 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
@@ -462,44 +411,13 @@ std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
   int index = -1;
   StringRef name = getValuePackName(symbol, &index);
 
-  auto it = symbolInfoMap.find(name.str());
+  auto it = symbolInfoMap.find(name);
   if (it == symbolInfoMap.end()) {
     auto error = formatv("referencing unbound symbol '{0}'", symbol);
     PrintFatalError(loc, error);
   }
 
-  return it->second.getAllRangeUse(name, index, fmt, separator);
-}
-
-void SymbolInfoMap::assignUniqueAlternativeNames() {
-  llvm::StringSet<> usedNames;
-
-  for (auto symbolInfoIt = symbolInfoMap.begin();
-       symbolInfoIt != symbolInfoMap.end();) {
-    auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
-    auto startRange = range.first;
-    auto endRange = range.second;
-
-    auto operandName = symbolInfoIt->first;
-    int startSearchIndex = 0;
-    for (++startRange; startRange != endRange; ++startRange) {
-      // Current operand name is not unique, find a unique one
-      // and set the alternative name.
-      for (int i = startSearchIndex;; ++i) {
-        std::string alternativeName = operandName + std::to_string(i);
-        if (!usedNames.contains(alternativeName) &&
-            symbolInfoMap.count(alternativeName) == 0) {
-          usedNames.insert(alternativeName);
-          startRange->second.alternativeName = alternativeName;
-          startSearchIndex = i + 1;
-
-          break;
-        }
-      }
-    }
-
-    symbolInfoIt = endRange;
-  }
+  return it->getValue().getAllRangeUse(name, index, fmt, separator);
 }
 
 //===----------------------------------------------------------------------===//
@@ -527,10 +445,6 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
-
-  LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
-  infoMap.assignUniqueAlternativeNames();
-  LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
 }
 
 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index aef39a9e19fec..d36d7bd58ea8d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -619,32 +619,6 @@ def OpM : TEST_Op<"op_m"> {
   let results = (outs I32);
 }
 
-def OpN : TEST_Op<"op_n"> {
-  let arguments = (ins I32, I32);
-  let results = (outs I32);
-}
-
-def OpO : TEST_Op<"op_o"> {
-  let arguments = (ins I32);
-  let results = (outs I32);
-}
-
-def OpP : TEST_Op<"op_p"> {
-  let arguments = (ins I32, I32, I32, I32, I32, I32);
-  let results = (outs I32);
-}
-
-// Test same operand name enforces equality condition check.
-def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
-
-// Test when equality is enforced at different depth.
-def TestNestedOpEqualArgsPattern :
-  Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
-
-// Test multiple equal arguments check enforced.
-def TestMultipleEqualArgsPattern :
-  Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
-
 // Test for memrefs normalization of an op with normalizable memrefs.
 def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5986be6240f9b..0f2f434c928f9 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -111,64 +111,6 @@ func @verifyManyArgs(%arg: i32) {
   return
 }
 
-// CHECK-LABEL: verifyEqualArgs
-func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
-  // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
-
-  // CHECK: "test.op_o"(%arg0) : (i32) -> i32
-  "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
-
-  // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
-  "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
-
-  return
-}
-
-// CHECK-LABEL: verifyNestedOpEqualArgs
-func @verifyNestedOpEqualArgs(
-  %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
-  // def TestNestedOpEqualArgsPattern :
-  //   Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
-
-  // CHECK: %arg1
-  %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
-    : (i32, i32, i32, i32, i32, i32) -> (i32)
-  %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
-
-  // CHECK: test.op_p
-  // CHECK: test.op_n
-  %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
-    : (i32, i32, i32, i32, i32, i32) -> (i32)
-  %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
-
-  return
-}
-
-// CHECK-LABEL: verifyMultipleEqualArgs
-func @verifyMultipleEqualArgs(
-  %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
-  // def TestMultipleEqualArgsPattern :
-  //   Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
-
-  // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
-  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
-    (i32, i32, i32, i32 , i32, i32) -> i32
-
-  // CHECK: test.op_p
-  "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
-    (i32, i32, i32, i32 , i32, i32) -> i32
-
-  // CHECK: test.op_p
-  "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
-    (i32, i32, i32, i32 , i32, i32) -> i32
-
-   // CHECK: test.op_p
-  "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
-    (i32, i32, i32, i32 , i32, i32) -> i32
-
-  return
-}
-
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 7bff3e3b40b62..ff6138f73914e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -89,11 +89,6 @@ class PatternEmitter {
   void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
                       const llvm::formatv_object_base &failureFmt);
 
-  // Emits C++ for checking a match with a corresponding match failure
-  // diagnostics.
-  void emitMatchCheck(int depth, const std::string &matchStr,
-                      const std::string &failureStr);
-
   //===--------------------------------------------------------------------===//
   // Rewrite utilities
   //===--------------------------------------------------------------------===//
@@ -332,9 +327,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
-    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
-                  res->second.getVarName(name), depth, argIndex - numPrevAttrs);
+    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
+                  argIndex - numPrevAttrs);
   }
 }
 
@@ -399,15 +393,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
 void PatternEmitter::emitMatchCheck(
     int depth, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
-}
-
-void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
-                                    const std::string &failureStr) {
-  os << "if (!(" << matchStr << "))";
+  os << "if (!(" << matchFmt.str() << "))";
   os.scope("{\n", "\n}\n").os
       << "return rewriter.notifyMatchFailure(op" << depth
-      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureStr
+      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
       << ";\n});";
 }
 
@@ -456,30 +445,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
                              constraint.getDescription()));
     }
   }
-
-  // Some of the operands could be bound to the same symbol name, we need
-  // to enforce equality constraint on those.
-  // TODO: we should be able to emit equality checks early
-  // and short circuit unnecessary work if vars are not equal.
-  for (auto symbolInfoIt = symbolInfoMap.begin();
-       symbolInfoIt != symbolInfoMap.end();) {
-    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
-    auto startRange = range.first;
-    auto endRange = range.second;
-
-    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
-    for (++startRange; startRange != endRange; ++startRange) {
-      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
-      emitMatchCheck(
-          depth,
-          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
-          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
-                  secondOperand));
-    }
-
-    symbolInfoIt = endRange;
-  }
-
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
 }
 
@@ -553,9 +518,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
       // Create local variables for storing the arguments and results bound
       // to symbols.
       for (const auto &symbolInfoPair : symbolInfoMap) {
-        const auto &symbol = symbolInfoPair.first;
-        const auto &info = symbolInfoPair.second;
-
+        StringRef symbol = symbolInfoPair.getKey();
+        auto &info = symbolInfoPair.getValue();
         os << info.getVarDecl(symbol);
       }
       // TODO: capture ops with consistent numbering so that it can be
@@ -1129,7 +1093,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                     range);
     } else {
-      os << formatv("tblgen_values.push_back(");
+      os << formatv("tblgen_values.push_back(", varName);
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(
             childNodeNames.lookup(argIndex));


        


More information about the Mlir-commits mailing list