[Mlir-commits] [mlir] 0b793c4 - Revert "[DDR] Introduce implicit equality check for the source pattern operands with the same name."
Mehdi Amini
llvmlistbot at llvm.org
Tue Oct 13 17:38:28 PDT 2020
Author: Mehdi Amini
Date: 2020-10-14T00:37:10Z
New Revision: 0b793c4be0eee90a22b7a150187f5f7cf744c120
URL: https://github.com/llvm/llvm-project/commit/0b793c4be0eee90a22b7a150187f5f7cf744c120
DIFF: https://github.com/llvm/llvm-project/commit/0b793c4be0eee90a22b7a150187f5f7cf744c120.diff
LOG: Revert "[DDR] Introduce implicit equality check for the source pattern operands with the same name."
This reverts commit 7271c1bcb96051bcd227d3fa6071a620fe238850.
This broke the gcc-5 build:
/usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair<const std::__cxx11::basic_string<char>, mlir::tblgen::SymbolInfoMap::SymbolInfo>::pair(llvm::StringRef&, mlir::tblgen::SymbolInfoMap::SymbolInfo)'
{ ::new((void *)__p) _Up(std::forward<_Args>(__args)...); }
^
In file included from /usr/include/c++/5/utility:70:0,
from llvm/include/llvm/Support/type_traits.h:18,
from llvm/include/llvm/Support/Casting.h:18,
from mlir/include/mlir/Support/LLVM.h:24,
from mlir/include/mlir/TableGen/Pattern.h:17,
from mlir/lib/TableGen/Pattern.cpp:14:
/usr/include/c++/5/bits/stl_pair.h:206:9: note: candidate: template<class ... _Args1, long unsigned int ..._Indexes1, class ... _Args2, long unsigned int ..._Indexes2> std::pair<_T1, _T2>::pair(std::tuple<_Args1 ...>&, std::tuple<_Args2 ...>&, std::_Index_tuple<_Indexes1 ...>, std::_Index_tuple<_Indexes2 ...>)
pair(tuple<_Args1...>&, tuple<_Args2...>&,
^
Added:
Modified:
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4fc2ae762a667..a5759e358f695 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -21,8 +21,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
-#include <unordered_map>
-
namespace llvm {
class DagInit;
class Init;
@@ -230,9 +228,6 @@ class SymbolInfoMap {
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;
- // Returns a variable name for the symbol named as `name`.
- std::string getVarName(StringRef name) const;
-
private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
@@ -290,12 +285,9 @@ class SymbolInfoMap {
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
- // Alternative name for the symbol. It is used in case the name
- // is not unique. Applicable for `Operand` only.
- Optional<std::string> alternativeName;
};
- using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
+ using BaseT = llvm::StringMap<SymbolInfo>;
// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
@@ -308,7 +300,7 @@ class SymbolInfoMap {
const_iterator end() const { return symbolInfoMap.end(); }
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
- // Returns false if `symbol` is already bound and symbols are not operands.
+ // Returns false if `symbol` is already bound.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
// Binds the given `symbol` to the results the given `op`. Returns false if
@@ -325,18 +317,6 @@ class SymbolInfoMap {
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;
- // Returns an iterator to the information of the given symbol named as `key`,
- // with index `argIndex` for operator `op`.
- const_iterator findBoundSymbol(StringRef key, const Operator &op,
- int argIndex) const;
-
- // Returns the bounds of a range that includes all the elements which
- // bind to the `key`.
- std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
-
- // Returns number of times symbol named as `key` was used.
- int count(StringRef key) const;
-
// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
@@ -358,9 +338,6 @@ class SymbolInfoMap {
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
- // Assign alternative unique names to Operands that have equal names.
- void assignUniqueAlternativeNames();
-
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
@@ -370,7 +347,7 @@ class SymbolInfoMap {
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
private:
- BaseT symbolInfoMap;
+ llvm::StringMap<SymbolInfo> symbolInfoMap;
// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 5170f07870abb..cfa3da2c417a4 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -208,10 +208,6 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
llvm_unreachable("unknown kind");
}
-std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
- return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
-}
-
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
@@ -223,9 +219,8 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
- return std::string(
- formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
- getVarName(name)));
+ return std::string(formatv(
+ "::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
}
case Kind::Value: {
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@@ -364,34 +359,16 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);
- std::string key = symbol.str();
- if (auto numberOfEntries = symbolInfoMap.count(key)) {
- // Only non unique name for the operand is supported.
- if (symInfo.kind != SymbolInfo::Kind::Operand) {
- return false;
- }
-
- // Cannot add new operand if there is already non operand with the same
- // name.
- if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
- return false;
- }
- }
-
- symbolInfoMap.emplace(key, symInfo);
- return true;
+ return symbolInfoMap.insert({symbol, symInfo}).second;
}
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
StringRef name = getValuePackName(symbol);
- auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
-
- return symbolInfoMap.count(inserted->first) == 1;
+ return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
}
bool SymbolInfoMap::bindValue(StringRef symbol) {
- auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
- return symbolInfoMap.count(inserted->first) == 1;
+ return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
}
bool SymbolInfoMap::contains(StringRef symbol) const {
@@ -399,38 +376,10 @@ bool SymbolInfoMap::contains(StringRef symbol) const {
}
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
- std::string name = getValuePackName(key).str();
-
+ StringRef name = getValuePackName(key);
return symbolInfoMap.find(name);
}
-SymbolInfoMap::const_iterator
-SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
- int argIndex) const {
- std::string name = getValuePackName(key).str();
- auto range = symbolInfoMap.equal_range(name);
-
- for (auto it = range.first; it != range.second; ++it) {
- if (it->second.op == &op && it->second.argIndex == argIndex) {
- return it;
- }
- }
-
- return symbolInfoMap.end();
-}
-
-std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
-SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
- std::string name = getValuePackName(key).str();
-
- return symbolInfoMap.equal_range(name);
-}
-
-int SymbolInfoMap::count(StringRef key) const {
- std::string name = getValuePackName(key).str();
- return symbolInfoMap.count(name);
-}
-
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
@@ -439,7 +388,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
- return find(name)->second.getStaticValueCount();
+ return find(name)->getValue().getStaticValueCount();
}
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@@ -448,13 +397,13 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
- auto it = symbolInfoMap.find(name.str());
+ auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
- return it->second.getValueAndRangeUse(name, index, fmt, separator);
+ return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
}
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
@@ -462,44 +411,13 @@ std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
- auto it = symbolInfoMap.find(name.str());
+ auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
- return it->second.getAllRangeUse(name, index, fmt, separator);
-}
-
-void SymbolInfoMap::assignUniqueAlternativeNames() {
- llvm::StringSet<> usedNames;
-
- for (auto symbolInfoIt = symbolInfoMap.begin();
- symbolInfoIt != symbolInfoMap.end();) {
- auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
- auto startRange = range.first;
- auto endRange = range.second;
-
- auto operandName = symbolInfoIt->first;
- int startSearchIndex = 0;
- for (++startRange; startRange != endRange; ++startRange) {
- // Current operand name is not unique, find a unique one
- // and set the alternative name.
- for (int i = startSearchIndex;; ++i) {
- std::string alternativeName = operandName + std::to_string(i);
- if (!usedNames.contains(alternativeName) &&
- symbolInfoMap.count(alternativeName) == 0) {
- usedNames.insert(alternativeName);
- startRange->second.alternativeName = alternativeName;
- startSearchIndex = i + 1;
-
- break;
- }
- }
- }
-
- symbolInfoIt = endRange;
- }
+ return it->getValue().getAllRangeUse(name, index, fmt, separator);
}
//===----------------------------------------------------------------------===//
@@ -527,10 +445,6 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
-
- LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
- infoMap.assignUniqueAlternativeNames();
- LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index aef39a9e19fec..d36d7bd58ea8d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -619,32 +619,6 @@ def OpM : TEST_Op<"op_m"> {
let results = (outs I32);
}
-def OpN : TEST_Op<"op_n"> {
- let arguments = (ins I32, I32);
- let results = (outs I32);
-}
-
-def OpO : TEST_Op<"op_o"> {
- let arguments = (ins I32);
- let results = (outs I32);
-}
-
-def OpP : TEST_Op<"op_p"> {
- let arguments = (ins I32, I32, I32, I32, I32, I32);
- let results = (outs I32);
-}
-
-// Test same operand name enforces equality condition check.
-def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
-
-// Test when equality is enforced at different depth.
-def TestNestedOpEqualArgsPattern :
- Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
-
-// Test multiple equal arguments check enforced.
-def TestMultipleEqualArgsPattern :
- Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
-
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5986be6240f9b..0f2f434c928f9 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -111,64 +111,6 @@ func @verifyManyArgs(%arg: i32) {
return
}
-// CHECK-LABEL: verifyEqualArgs
-func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
- // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
-
- // CHECK: "test.op_o"(%arg0) : (i32) -> i32
- "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
-
- // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
- "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
-
- return
-}
-
-// CHECK-LABEL: verifyNestedOpEqualArgs
-func @verifyNestedOpEqualArgs(
- %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
- // def TestNestedOpEqualArgsPattern :
- // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
-
- // CHECK: %arg1
- %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
- : (i32, i32, i32, i32, i32, i32) -> (i32)
- %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
-
- // CHECK: test.op_p
- // CHECK: test.op_n
- %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
- : (i32, i32, i32, i32, i32, i32) -> (i32)
- %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
-
- return
-}
-
-// CHECK-LABEL: verifyMultipleEqualArgs
-func @verifyMultipleEqualArgs(
- %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
- // def TestMultipleEqualArgsPattern :
- // Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
-
- // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
- "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
- (i32, i32, i32, i32 , i32, i32) -> i32
-
- // CHECK: test.op_p
- "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
- (i32, i32, i32, i32 , i32, i32) -> i32
-
- // CHECK: test.op_p
- "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
- (i32, i32, i32, i32 , i32, i32) -> i32
-
- // CHECK: test.op_p
- "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
- (i32, i32, i32, i32 , i32, i32) -> i32
-
- return
-}
-
//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 7bff3e3b40b62..ff6138f73914e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -89,11 +89,6 @@ class PatternEmitter {
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt);
- // Emits C++ for checking a match with a corresponding match failure
- // diagnostics.
- void emitMatchCheck(int depth, const std::string &matchStr,
- const std::string &failureStr);
-
//===--------------------------------------------------------------------===//
// Rewrite utilities
//===--------------------------------------------------------------------===//
@@ -332,9 +327,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
- auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
- os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
- res->second.getVarName(name), depth, argIndex - numPrevAttrs);
+ os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
+ argIndex - numPrevAttrs);
}
}
@@ -399,15 +393,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
- emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
-}
-
-void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
- const std::string &failureStr) {
- os << "if (!(" << matchStr << "))";
+ os << "if (!(" << matchFmt.str() << "))";
os.scope("{\n", "\n}\n").os
<< "return rewriter.notifyMatchFailure(op" << depth
- << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
+ << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
<< ";\n});";
}
@@ -456,30 +445,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
constraint.getDescription()));
}
}
-
- // Some of the operands could be bound to the same symbol name, we need
- // to enforce equality constraint on those.
- // TODO: we should be able to emit equality checks early
- // and short circuit unnecessary work if vars are not equal.
- for (auto symbolInfoIt = symbolInfoMap.begin();
- symbolInfoIt != symbolInfoMap.end();) {
- auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
- auto startRange = range.first;
- auto endRange = range.second;
-
- auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
- for (++startRange; startRange != endRange; ++startRange) {
- auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
- emitMatchCheck(
- depth,
- formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
- formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
- secondOperand));
- }
-
- symbolInfoIt = endRange;
- }
-
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
@@ -553,9 +518,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
- const auto &symbol = symbolInfoPair.first;
- const auto &info = symbolInfoPair.second;
-
+ StringRef symbol = symbolInfoPair.getKey();
+ auto &info = symbolInfoPair.getValue();
os << info.getVarDecl(symbol);
}
// TODO: capture ops with consistent numbering so that it can be
@@ -1129,7 +1093,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
range);
} else {
- os << formatv("tblgen_values.push_back(");
+ os << formatv("tblgen_values.push_back(", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));
More information about the Mlir-commits
mailing list