[Mlir-commits] [mlir] 2bf423b - [mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher
Rob Suderman
llvmlistbot at llvm.org
Thu Oct 15 16:34:10 PDT 2020
Author: Rob Suderman
Date: 2020-10-15T16:32:20-07:00
New Revision: 2bf423b0218c9583e3a372950a34facbf93e63d3
URL: https://github.com/llvm/llvm-project/commit/2bf423b0218c9583e3a372950a34facbf93e63d3
DIFF: https://github.com/llvm/llvm-project/commit/2bf423b0218c9583e3a372950a34facbf93e63d3.diff
LOG: [mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher
Added an underlying matcher for generic constant ops. This
included a rewrite of RewriterGen to make variable use
clearer.
Differential Revision: https://reviews.llvm.org/D89161
Added:
mlir/test/mlir-tblgen/rewriter-errors.td
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 82dc6a456f29..72b3b1ab41f5 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2351,6 +2351,8 @@ class NativeCodeCall<string expr> {
string expression = expr;
}
+def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
+
//===----------------------------------------------------------------------===//
// Rewrite directives
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4fc2ae762a66..98c5d9b18f5d 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -252,6 +252,9 @@ class SymbolInfoMap {
static SymbolInfo getAttr(const Operator *op, int index) {
return SymbolInfo(op, Kind::Attr, index);
}
+ static SymbolInfo getAttr() {
+ return SymbolInfo(nullptr, Kind::Attr, llvm::None);
+ }
static SymbolInfo getOperand(const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand, index);
}
@@ -319,6 +322,10 @@ class SymbolInfoMap {
// is already bound.
bool bindValue(StringRef symbol);
+ // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
+ // is already bound.
+ bool bindAttr(StringRef symbol);
+
// Returns true if the given `symbol` is bound.
bool contains(StringRef symbol) const;
@@ -421,6 +428,9 @@ class Pattern {
std::vector<IdentifierLine> getLocation() const;
private:
+ // Helper function to verify variable binding.
+ void verifyBind(bool result, StringRef symbolName);
+
// Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `infoMap`.
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 448f70359bd0..7044677fad36 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
case Kind::Attr: {
- auto type =
- op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
- return std::string(formatv("{0} {1};\n", type, name));
+ if (op) {
+ auto type =
+ op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+ return std::string(formatv("{0} {1};\n", type, name));
+ }
+ // TODO(suderman): Use a more exact type when available.
+ return std::string(formatv("Attribute {0};\n", name));
}
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
@@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) {
return symbolInfoMap.count(inserted->first) == 1;
}
+bool SymbolInfoMap::bindAttr(StringRef symbol) {
+ auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr());
+ return symbolInfoMap.count(inserted->first) == 1;
+}
+
bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end();
}
@@ -558,15 +567,15 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
for (auto it : *listInit) {
auto *dagInit = dyn_cast<llvm::DagInit>(it);
if (!dagInit)
- PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
- "constraints should be DAG nodes");
+ PrintFatalError(&def, "all elements in Pattern multi-entity "
+ "constraints should be DAG nodes");
std::vector<std::string> entities;
entities.reserve(dagInit->arg_size());
for (auto *argName : dagInit->getArgNames()) {
if (!argName) {
PrintFatalError(
- def.getLoc(),
+ &def,
"operands to additional constraints can only be symbol references");
}
entities.push_back(std::string(argName->getValue()));
@@ -584,7 +593,7 @@ int Pattern::getBenefit() const {
int initBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
- PrintFatalError(def.getLoc(),
+ PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value");
}
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
@@ -603,64 +612,120 @@ std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
return result;
}
+void Pattern::verifyBind(bool result, StringRef symbolName) {
+ if (!result) {
+ auto err = formatv("symbol '{0}' bound more than once", symbolName);
+ PrintFatalError(&def, err);
+ }
+}
+
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
- if (!tree.isOperation()) {
+ auto numTreeArgs = tree.getNumArgs();
+
+ if (tree.isNativeCodeCall()) {
if (!treeName.empty()) {
PrintFatalError(
- def.getLoc(),
- formatv("binding symbol '{0}' to non-operation unsupported right now",
- treeName));
+ &def,
+ formatv(
+ "binding symbol '{0}' to native code call unsupported right now",
+ treeName));
}
- return;
- }
- auto &op = getDialectOp(tree);
- auto numOpArgs = op.getNumArgs();
- auto numTreeArgs = tree.getNumArgs();
-
- // The pattern might have the last argument specifying the location.
- bool hasLocDirective = false;
- if (numTreeArgs != 0) {
- if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
- hasLocDirective = lastArg.isLocationDirective();
- }
+ for (int i = 0; i != numTreeArgs; ++i) {
+ if (auto treeArg = tree.getArgAsNestedDag(i)) {
+ // This DAG node argument is a DAG node itself. Go inside recursively.
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+ continue;
+ }
- if (numOpArgs != numTreeArgs - hasLocDirective) {
- auto err = formatv("op '{0}' argument number mismatch: "
- "{1} in pattern vs. {2} in definition",
- op.getOperationName(), numTreeArgs, numOpArgs);
- PrintFatalError(def.getLoc(), err);
- }
+ if (!isSrcPattern)
+ continue;
- // The name attached to the DAG node's operator is for representing the
- // results generated from this op. It should be remembered as bound results.
- if (!treeName.empty()) {
- LLVM_DEBUG(llvm::dbgs()
- << "found symbol bound to op result: " << treeName << '\n');
- if (!infoMap.bindOpResult(treeName, op))
- PrintFatalError(def.getLoc(),
- formatv("symbol '{0}' bound more than once", treeName));
- }
-
- for (int i = 0; i != numTreeArgs; ++i) {
- if (auto treeArg = tree.getArgAsNestedDag(i)) {
- // This DAG node argument is a DAG node itself. Go inside recursively.
- collectBoundSymbols(treeArg, infoMap, isSrcPattern);
- } else if (isSrcPattern) {
- // We can only bind symbols to op arguments in source pattern. Those
+ // We can only bind symbols to arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
+
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
- LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
- << treeArgName << '\n');
- if (!infoMap.bindOpArgument(treeArgName, op, i)) {
- auto err = formatv("symbol '{0}' bound more than once", treeArgName);
- PrintFatalError(def.getLoc(), err);
+ if (tree.isNestedDagArg(i)) {
+ auto err = formatv("cannot bind '{0}' for nested native call arg",
+ treeArgName);
+ PrintFatalError(&def, err);
}
+
+ DagLeaf leaf = tree.getArgAsLeaf(i);
+ auto constraint = leaf.getAsConstraint();
+ bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+ leaf.isConstantAttr() ||
+ constraint.getKind() == Constraint::Kind::CK_Attr;
+
+ if (isAttr) {
+ verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
+ continue;
+ }
+
+ verifyBind(infoMap.bindValue(treeArgName), treeArgName);
}
}
+
+ return;
+ }
+
+ if (tree.isOperation()) {
+ auto &op = getDialectOp(tree);
+ auto numOpArgs = op.getNumArgs();
+
+ // The pattern might have the last argument specifying the location.
+ bool hasLocDirective = false;
+ if (numTreeArgs != 0) {
+ if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
+ hasLocDirective = lastArg.isLocationDirective();
+ }
+
+ if (numOpArgs != numTreeArgs - hasLocDirective) {
+ auto err = formatv("op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ op.getOperationName(), numTreeArgs, numOpArgs);
+ PrintFatalError(&def, err);
+ }
+
+ // The name attached to the DAG node's operator is for representing the
+ // results generated from this op. It should be remembered as bound results.
+ if (!treeName.empty()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "found symbol bound to op result: " << treeName << '\n');
+ verifyBind(infoMap.bindOpResult(treeName, op), treeName);
+ }
+
+ for (int i = 0; i != numTreeArgs; ++i) {
+ if (auto treeArg = tree.getArgAsNestedDag(i)) {
+ // This DAG node argument is a DAG node itself. Go inside recursively.
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+ continue;
+ }
+
+ if (isSrcPattern) {
+ // We can only bind symbols to op arguments in source pattern. Those
+ // symbols are referenced in result patterns.
+ auto treeArgName = tree.getArgName(i);
+ // `$_` is a special symbol meaning ignore the current argument.
+ if (!treeArgName.empty() && treeArgName != "_") {
+ LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
+ << treeArgName << '\n');
+ verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
+ }
+ }
+ }
+ return;
+ }
+
+ if (!treeName.empty()) {
+ PrintFatalError(
+ &def, formatv("binding symbol '{0}' to non-operation/native code call "
+ "unsupported right now",
+ treeName));
}
+ return;
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3bfb82495ce1..d34e997644a5 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
return operand();
}
+OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
+ return getValue();
+}
+
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
for (Value input : this->operands()) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index aef39a9e19fe..fcc677361dcc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
let hasCanonicalizer = 1;
}
+def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs AnyType);
+ let extraClassDeclaration = [{
+ Attribute getValue() { return getAttr("value"); }
+ }];
+
+ let hasFolder = 1;
+}
+
+def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
+def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
+
+def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
+ (OpS:$unused $input1, $input2)>;
+
// Op for testing trivial removal via folding of op with inner ops and no uses.
def TestOpWithRegionFoldNoSideEffect : TEST_Op<
"op_with_region_fold_no_side_effect", [NoSideEffect]> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 32d618d9008e..282d31065549 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -9,6 +9,7 @@
#include "TestDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5986be6240f9..616e116cb170 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) {
return %0, %1 : i32, i32
}
+//===----------------------------------------------------------------------===//
+// Test Constant Matching
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: testConstOp
+func @testConstOp() -> (i32) {
+ // CHECK-NEXT: [[C0:%.+]] = constant 1
+ %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+ // CHECK-NEXT: return [[C0]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: testConstOpUsed
+func @testConstOpUsed() -> (i32) {
+ // CHECK-NEXT: [[C0:%.+]] = constant 1
+ %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+ // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
+ %1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32
+
+ // CHECK-NEXT: return [[V0]]
+ return %1 : i32
+}
+
+// CHECK-LABEL: testConstOpReplaced
+func @testConstOpReplaced() -> (i32) {
+ // CHECK-NEXT: [[C0:%.+]] = constant 1
+ %0 = "test.constant"() {value = 1 : i32} : () -> i32
+ %1 = "test.constant"() {value = 2 : i32} : () -> i32
+
+ // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
+ %2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
+
+ // CHECK: [[V0]]
+ return %2 : i32
+}
+// CHECK-LABEL: testConstOpMatchFailure
+func @testConstOpMatchFailure() -> (i64) {
+ // CHECK-DAG: [[C0:%.+]] = constant 1
+ %0 = "test.constant"() {value = 1 : i64} : () -> i64
+
+ // CHECK-DAG: [[C1:%.+]] = constant 2
+ %1 = "test.constant"() {value = 2 : i64} : () -> i64
+
+ // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
+ %2 = "test.op_r"(%0, %1) : (i64, i64) -> i64
+
+ // CHECK: [[V0]]
+ return %2 : i64
+}
+
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td
new file mode 100644
index 000000000000..eeb049482b88
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-errors.td
@@ -0,0 +1,29 @@
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
+
+include "mlir/IR/OpBase.td"
+
+// Check using the dialect name as the namespace
+def A_Dialect : Dialect {
+ let name = "a";
+}
+
+class A_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<A_Dialect, mnemonic, traits>;
+
+def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
+def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
+
+#ifdef ERROR1
+def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
+// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
+def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
+ (OpB $val, $arg)>;
+#endif
+
+#ifdef ERROR2
+def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
+// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
+def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
+ (OpB $val, $arg)>;
+#endif
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 7bff3e3b40b6..5521eea38252 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -63,7 +63,7 @@ class PatternEmitter {
private:
// Emits the code for matching ops.
- void emitMatchLogic(DagNode tree);
+ void emitMatchLogic(DagNode tree, StringRef opName);
// Emits the code for rewriting ops.
void emitRewriteLogic();
@@ -72,26 +72,34 @@ class PatternEmitter {
// Match utilities
//===--------------------------------------------------------------------===//
+ // Emits C++ statements for matching the DAG structure.
+ void emitMatch(DagNode tree, StringRef name, int depth);
+
+ // Emits C++ statements for matching using a native code call.
+ void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
+
// Emits C++ statements for matching the op constrained by the given DAG
- // `tree`.
- void emitOpMatch(DagNode tree, int depth);
+ // `tree` returning the op's variable name.
+ void emitOpMatch(DagNode tree, StringRef opName, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand.
- void emitOperandMatch(DagNode tree, int argIndex, int depth);
+ void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
+ int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, int argIndex, int depth);
+ void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+ int depth);
// Emits C++ for checking a match with a corresponding match failure
// diagnostic.
- void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
+ void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt);
// Emits C++ for checking a match with a corresponding match failure
// diagnostics.
- void emitMatchCheck(int depth, const std::string &matchStr,
+ void emitMatchCheck(StringRef opName, const std::string &matchStr,
const std::string &failureStr);
//===--------------------------------------------------------------------===//
@@ -113,7 +121,7 @@ class PatternEmitter {
// Emits the C++ statement to replace the matched DAG with a value built via
// calling native C++ code.
- std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
+ std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
// Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree);
@@ -140,12 +148,13 @@ class PatternEmitter {
// Emits the concrete arguments used to call an op's builder.
void supplyValuesForOpArgs(DagNode node,
- const ChildNodeIndexNameMap &childNodeNames);
+ const ChildNodeIndexNameMap &childNodeNames,
+ int depth);
// Emits the local variables for holding all values as a whole and all named
// attributes as a whole to be used for creating an op.
void createAggregateLocalVarsForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
@@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
}
// Helper function to match patterns.
-void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
+void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
+ if (tree.isNativeCodeCall()) {
+ emitNativeCodeMatch(tree, name, depth);
+ return;
+ }
+
+ if (tree.isOperation()) {
+ emitOpMatch(tree, name, depth);
+ return;
+ }
+
+ PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
+ int depth) {
+ LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
+ LLVM_DEBUG(tree.print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << '\n');
+
+ // TODO(suderman): iterate through arguments, determine their types, output
+ // names.
+ SmallVector<std::string, 8> capture(8);
+ if (tree.getNumArgs() > 8) {
+ PrintFatalError(loc,
+ "unsupported NativeCodeCall matcher argument numbers: " +
+ Twine(tree.getNumArgs()));
+ }
+
+ raw_indented_ostream::DelimitedScope scope(os);
+
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ std::string argName = formatv("arg{0}_{1}", depth, i);
+ if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ os << "Value " << argName << ";\n";
+ } else {
+ auto leaf = tree.getArgAsLeaf(i);
+ if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
+ os << "Attribute " << argName << ";\n";
+ } else if (leaf.isOperandMatcher()) {
+ os << "Operation " << argName << ";\n";
+ }
+ }
+
+ capture[i] = std::move(argName);
+ }
+
+ bool hasLocationDirective;
+ std::string locToUse;
+ std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+
+ auto fmt = tree.getNativeCodeTemplate();
+ auto nativeCodeCall = std::string(tgfmt(
+ fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
+ capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+
+ os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
+
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ auto name = tree.getArgName(i);
+ if (!name.empty() && name != "_") {
+ os << formatv("{0} = {1};\n", name, capture[i]);
+ }
+ }
+
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ std::string argName = capture[i];
+
+ // Handle nested DAG construct first
+ if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+ PrintFatalError(
+ loc, formatv("Matching nested tree in NativeCodecall not support for "
+ "{0} as arg {1}",
+ argName, i));
+ }
+
+ DagLeaf leaf = tree.getArgAsLeaf(i);
+ auto constraint = leaf.getAsConstraint();
+
+ auto self = formatv("{0}", argName);
+ emitMatchCheck(
+ opName,
+ tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+ formatv("\"operand {0} of native code call '{1}' failed to satisfy "
+ "constraint: "
+ "'{2}'\"",
+ i, tree.getNativeCodeTemplate(), constraint.getDescription()));
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
Operator &op = tree.getDialectOp(opMap);
LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
<< op.getOperationName() << "' at depth " << depth
<< '\n');
- int indent = 4 + 2 * depth;
- os.indent(indent) << formatv(
- "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); "
- "(void)castedOp{0};\n",
- depth, op.getQualCppClassName());
+ std::string castedName = formatv("castedOp{0}", depth);
+ os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
+ "(void){0};\n",
+ castedName, opName, op.getQualCppClassName());
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
- os << formatv("if (!castedOp{0})\n return failure();\n", depth);
+ os << formatv("if (!{0}) return failure();\n", castedName);
}
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
- os << formatv("{0} = castedOp{1};\n", name, depth);
+ os << formatv("{0} = {1};\n", name, castedName);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
+ std::string argName = formatv("op{0}", depth + 1);
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
os << "{\n";
os.indent() << formatv(
- "auto *op{0} = "
- "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
- depth + 1, depth, i);
- emitOpMatch(argTree, depth + 1);
- os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+ "auto *{0} = "
+ "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
+ argName, castedName, i);
+ emitMatch(argTree, argName, depth + 1);
+ os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
os.unindent() << "}\n";
continue;
}
// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
- emitOperandMatch(tree, i, depth);
+ emitOperandMatch(tree, castedName, i, depth);
} else if (opArg.is<NamedAttribute *>()) {
- emitAttributeMatch(tree, i, depth);
+ emitAttributeMatch(tree, opName, i, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
@@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
<< '\n');
}
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
+ int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex);
@@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
op.getOperationName(), argIndex);
PrintFatalError(loc, error);
}
- auto self =
- formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
- argIndex);
+ auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
+ opName, argIndex);
emitMatchCheck(
- depth,
+ opName,
tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
"'{2}'\"",
@@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
- os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
- res->second.getVarName(name), depth, argIndex - numPrevAttrs);
+ os << formatv("{0} = {1}.getODSOperands({2});\n",
+ res->second.getVarName(name), opName,
+ argIndex - numPrevAttrs);
}
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+ int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr;
os << "{\n";
- os.indent() << formatv(
- "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
- "(void)tblgen_attr;\n",
- depth, attr.getStorageType(), namedAttr->name);
+ os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+ "(void)tblgen_attr;\n",
+ opName, attr.getStorageType(), namedAttr->name);
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
@@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
// should just capture a mlir::Attribute() to signal the missing state.
// That is precisely what getAttr() returns on missing attributes.
} else {
- emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
+ emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
emitMatchCheck(
- depth,
+ opName,
tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"{2}\"",
@@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
}
void PatternEmitter::emitMatchCheck(
- int depth, const FmtObjectBase &matchFmt,
+ StringRef opName, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
- emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
+ emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
}
-void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
+void PatternEmitter::emitMatchCheck(StringRef opName,
+ const std::string &matchStr,
const std::string &failureStr) {
+
os << "if (!(" << matchStr << "))";
- os.scope("{\n", "\n}\n").os
- << "return rewriter.notifyMatchFailure(op" << depth
- << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
- << ";\n});";
+ os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
+ << ", [&](::mlir::Diagnostic &diag) {\n diag << "
+ << failureStr << ";\n});";
}
-void PatternEmitter::emitMatchLogic(DagNode tree) {
+void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
int depth = 0;
- emitOpMatch(tree, depth);
+ emitMatch(tree, opName, depth);
for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
@@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
auto self = formatv("({0}.getType())",
symbolInfoMap.getValueAndRangeUse(entities.front()));
emitMatchCheck(
- depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
+ opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
entities.front(), constraint.getDescription()));
@@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
self = symbolInfoMap.getValueAndRangeUse(self);
for (; i < 4; ++i)
names.push_back("<unused>");
- emitMatchCheck(depth,
+ emitMatchCheck(opName,
tgfmt(condition, &fmtCtx.withSelf(self), names[0],
names[1], names[2], names[3]),
formatv("\"entities '{0}' failed to satisfy constraint: "
@@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
for (++startRange; startRange != endRange; ++startRange) {
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
emitMatchCheck(
- depth,
+ opName,
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
secondOperand));
@@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
os << "// Match\n";
os << "tblgen_ops[0] = op0;\n";
- emitMatchLogic(sourceTree);
+ emitMatchLogic(sourceTree, "op0");
os << "\n// Rewrite\n";
emitRewriteLogic();
@@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
}
if (resultTree.isNativeCodeCall()) {
- auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+ auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
symbolInfoMap.bindValue(symbol);
return symbol;
}
@@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
PrintFatalError(loc, "unhandled case when rewriting op");
}
-std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
+ int depth) {
LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
LLVM_DEBUG(tree.print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << '\n');
@@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
// TODO: replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) {
- PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
- Twine(tree.getNumArgs()));
+ PrintFatalError(loc,
+ "unsupported NativeCodeCall replace argument numbers: " +
+ Twine(tree.getNumArgs()));
}
bool hasLocationDirective;
std::string locToUse;
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
- attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+ if (tree.isNestedDagArg(i)) {
+ attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+ } else {
+ attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+ }
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
<< " replacement: " << attrs[i] << "\n");
}
@@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// create the ops.
// First prepare local variables for op arguments used in builder call.
- createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then create the op.
os.scope("", "\n}\n").os << formatv(
@@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
resultOp.getQualCppClassName(), locToUse);
- supplyValuesForOpArgs(tree, childNodeNames);
+ supplyValuesForOpArgs(tree, childNodeNames, depth);
os << "\n );\n}\n";
return resultValue;
}
@@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here.
// First prepare local variables for op arguments used in builder call.
- createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then prepare the result types. We need to specify the types for all
// results.
@@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
}
void PatternEmitter::supplyValuesForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
argIndex != numOpArgs; ++argIndex) {
@@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs(
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
- handleReplaceWithNativeCodeCall(subTree));
+ handleReplaceWithNativeCodeCall(subTree, depth));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
@@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs(
}
void PatternEmitter::createAggregateLocalVarsForOpArgs(
- DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);
auto scope = os.scope();
@@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv(addAttrCmd, opArgName,
- handleReplaceWithNativeCodeCall(subTree));
+ handleReplaceWithNativeCodeCall(subTree, depth + 1));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
More information about the Mlir-commits
mailing list