[Mlir-commits] [mlir] b4001ae - [mlir-tblgen] Fix failed matching when binds same operand of an op in different depth
Chia-hung Duan
llvmlistbot at llvm.org
Tue Jul 20 00:45:25 PDT 2021
Author: Chia-hung Duan
Date: 2021-07-20T15:43:09+08:00
New Revision: b4001ae8851f47406f8187b72b0253b21bf1da4c
URL: https://github.com/llvm/llvm-project/commit/b4001ae8851f47406f8187b72b0253b21bf1da4c
DIFF: https://github.com/llvm/llvm-project/commit/b4001ae8851f47406f8187b72b0253b21bf1da4c.diff
LOG: [mlir-tblgen] Fix failed matching when binds same operand of an op in different depth
For example, we will generate incorrect code for the pattern,
def : Pat<(FooOp (FooOp $a, $b), $b), (...)>;
We didn't allow $b to be bound twice with the same operand of the same op.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D105677
Added:
Modified:
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 98c5d9b18f5dd..4b397c7c02bf6 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -184,6 +184,9 @@ class DagNode {
void print(raw_ostream &os) const;
private:
+ friend class SymbolInfoMap;
+ const void *getAsOpaquePointer() const { return node; }
+
const llvm::DagInit *node; // nullptr means null DagNode
};
@@ -237,6 +240,10 @@ class SymbolInfoMap {
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
+ // DagNode and DagLeaf are accessed by value which means it can't be used as
+ // identifier here. Use an opaque pointer type instead.
+ using DagAndIndex = std::pair<const void *, int>;
+
// What kind of entity this symbol represents:
// * Attr: op attribute
// * Operand: op operand
@@ -244,19 +251,21 @@ class SymbolInfoMap {
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
enum class Kind : uint8_t { Attr, Operand, Result, Value };
- // Creates a SymbolInfo instance. `index` is only used for `Attr` and
- // `Operand` so should be negative for `Result` and `Value` kind.
- SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
+ // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
+ // `Operand` so should be llvm::None for `Result` and `Value` kind.
+ SymbolInfo(const Operator *op, Kind kind,
+ Optional<DagAndIndex> dagAndIndex);
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
- return SymbolInfo(op, Kind::Attr, index);
+ return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
}
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
}
- static SymbolInfo getOperand(const Operator *op, int index) {
- return SymbolInfo(op, Kind::Operand, index);
+ static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
+ return SymbolInfo(op, Kind::Operand,
+ DagAndIndex(node.getAsOpaquePointer(), index));
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, llvm::None);
@@ -291,8 +300,11 @@ class SymbolInfoMap {
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
- // The argument index (for `Attr` and `Operand` only)
- Optional<int> argIndex;
+ // The pair of DagNode pointer and argument index (for `Attr` and `Operand`
+ // only). Note that operands may be bound to the same symbol, use the
+ // DagNode and index to distinguish them. For `Attr`, the Dag part will be
+ // nullptr.
+ Optional<DagAndIndex> dagAndIndex;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
@@ -312,7 +324,8 @@ class SymbolInfoMap {
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound and symbols are not operands.
- bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
+ bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
+ int argIndex);
// Binds the given `symbol` to the results the given `op`. Returns false if
// `symbol` is already bound.
@@ -334,8 +347,8 @@ class SymbolInfoMap {
// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
- const_iterator findBoundSymbol(StringRef key, const Operator &op,
- int argIndex) const;
+ const_iterator findBoundSymbol(StringRef key, DagNode node,
+ const Operator &op, int argIndex) const;
// Returns the bounds of a range that includes all the elements which
// bind to the `key`.
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index d3bd6f7662bff..a9b03519fb540 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -193,8 +193,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
}
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
- Optional<int> index)
- : op(op), kind(kind), argIndex(index) {}
+ Optional<DagAndIndex> dagAndIndex)
+ : op(op), kind(kind), dagAndIndex(dagAndIndex) {}
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
switch (kind) {
@@ -217,8 +217,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
switch (kind) {
case Kind::Attr: {
if (op) {
- auto type =
- op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+ auto type = op->getArg((*dagAndIndex).second)
+ .get<NamedAttribute *>()
+ ->attr.getStorageType();
return std::string(formatv("{0} {1};\n", type, name));
}
// TODO(suderman): Use a more exact type when available.
@@ -254,7 +255,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
}
case Kind::Operand: {
assert(index < 0);
- auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
+ auto *operand =
+ op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariableLength()) {
@@ -355,8 +357,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
llvm_unreachable("unknown kind");
}
-bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
- int argIndex) {
+bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
+ const Operator &op, int argIndex) {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
auto error = formatv(
@@ -366,7 +368,7 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
? SymbolInfo::getAttr(&op, argIndex)
- : SymbolInfo::getOperand(&op, argIndex);
+ : SymbolInfo::getOperand(node, &op, argIndex);
std::string key = symbol.str();
if (symbolInfoMap.count(key)) {
@@ -414,13 +416,15 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
}
SymbolInfoMap::const_iterator
-SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
int argIndex) const {
std::string name = getValuePackName(key).str();
auto range = symbolInfoMap.equal_range(name);
+ const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
+
for (auto it = range.first; it != range.second; ++it) {
- if (it->second.op == &op && it->second.argIndex == argIndex) {
+ if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
return it;
}
}
@@ -722,7 +726,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
- verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
+ verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
+ treeArgName);
}
}
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 16e141ede1736..c94ca64ba6958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -716,6 +716,13 @@ def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
def TestNestedOpEqualArgsPattern :
Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
+// Test when equality is enforced on same op and same operand but at different
+// depth. We only bound one of the $x to the second operand of outer OpN and
+// left another be the default value (which is the value of first operand of
+// outer OpN). As a result, it ended up comparing wrong values in some cases.
+def TestNestedSameOpAndSameArgEqualityPattern :
+ Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>;
+
// Test multiple equal arguments check enforced.
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index affc3d7a93968..69140dfb3fd47 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -158,6 +158,17 @@ func @verifyNestedOpEqualArgs(
return
}
+// CHECK-LABEL: verifyNestedSameOpAndSameArgEquality
+func @verifyNestedSameOpAndSameArgEquality(%arg0: i32, %arg1: i32) -> i32 {
+ // def TestNestedSameOpAndSameArgEqualityPattern:
+ // Pat<(OpN (OpN $_, $x), $x), (replaceWithValue $x)>;
+
+ %0 = "test.op_n"(%arg1, %arg0) : (i32, i32) -> (i32)
+ %1 = "test.op_n"(%0, %arg0) : (i32, i32) -> (i32)
+ // CHECK: return %arg0 : i32
+ return %1 : i32
+}
+
// CHECK-LABEL: verifyMultipleEqualArgs
func @verifyMultipleEqualArgs(
%arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 611bc5c1c05ea..9d3e4a93b53d6 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -454,7 +454,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
- auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
+ auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
os << formatv("{0} = {1}.getODSOperands({2});\n",
res->second.getVarName(name), opName,
argIndex - numPrevAttrs);
More information about the Mlir-commits
mailing list