[Mlir-commits] [mlir] d7314b3 - [mlir-tblgen] Support binding multi-results of NativeCodeCall
Chia-hung Duan
llvmlistbot at llvm.org
Tue Jul 20 20:35:55 PDT 2021
Author: Chia-hung Duan
Date: 2021-07-21T11:23:22+08:00
New Revision: d7314b3c094e96fbca7b195eab5fa521bda5fe22
URL: https://github.com/llvm/llvm-project/commit/d7314b3c094e96fbca7b195eab5fa521bda5fe22
DIFF: https://github.com/llvm/llvm-project/commit/d7314b3c094e96fbca7b195eab5fa521bda5fe22.diff
LOG: [mlir-tblgen] Support binding multi-results of NativeCodeCall
A NativeCodeCall result can be bound to a symbol in the same way as an
operation. To give mlir-tblgen a better understanding of the form of the
helper function, the number of return values must be specified in the
NativeCodeCall template. NativeCodeCallVoid is added for the void case.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D102160
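
As an illustration of the `$<name>__N` convention this commit relies on, the following is a minimal standalone sketch of how a bound symbol could be split into a pack name and a result index. `splitPackSymbol` is a hypothetical helper written for this note; it is not the `SymbolInfoMap::getValuePackName` implementation touched by the diff, and it uses `std::string` rather than `llvm::StringRef` so that it compiles without LLVM headers.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical helper, not the MLIR implementation: splits a bound symbol
// such as "res__1" into its pack name ("res") and result index (1).
// Returns index -1 when the symbol carries no numeric "__N" suffix.
std::pair<std::string, int> splitPackSymbol(const std::string &symbol) {
  const std::size_t pos = symbol.rfind("__");
  if (pos == std::string::npos || pos + 2 >= symbol.size())
    return {symbol, -1};

  int index = 0;
  for (std::size_t i = pos + 2; i < symbol.size(); ++i) {
    const char c = symbol[i];
    if (c < '0' || c > '9')
      return {symbol, -1}; // Suffix is not a plain number; keep the whole name.
    index = index * 10 + (c - '0');
  }
  return {symbol.substr(0, pos), index};
}
```

With this sketch, `res__0` and `res__1` both resolve to the pack name `res`, which is why a single binding of `$res__0` on the NativeCodeCall is enough to make `$res__1` available later in the result patterns.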
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 5815035ca77e8..c3518972bb0ef 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -377,9 +377,6 @@ template. The string can be an arbitrary C++ expression that evaluates into some
C++ object expected at the `NativeCodeCall` site (here it would be expecting an
array attribute). Typically the string should be a function call.
-Note that currently `NativeCodeCall` must return no more than one value or
-attribute. This might change in the future.
-
##### `NativeCodeCall` placeholders
In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
@@ -428,6 +425,30 @@ parameters at the `NativeCodeCall` use site. For example, if we define
`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
$in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`.
+##### `NativeCodeCall` binding multi-results
+
+To bind multi-results and access the N-th result with `$<name>__N`, specify the
+number of return values in the template. Note that only `Value` type is
+supported for multiple results binding. For example,
+
+```tablegen
+
+def PackAttrs : NativeCodeCall<"packAttrs($0, $1)", 2>;
+def : Pattern<(TwoResultOp $attr1, $attr2),
+ [(OneResultOp (PackAttrs:$res__0 $attr1, $attr2)),
+ (OneResultOp $res__1)]>;
+
+```
+
+Use `NativeCodeCallVoid` for the case that has no return value.
+
+The correct number of returned value specified in NativeCodeCall is important.
+It will be used to verify the consistency of the number of result values.
+Additionally, `mlir-tblgen` will try to capture the return value of
+NativeCodeCall in the generated code so that it will trigger a later compilation
+error if a NativeCodeCall that doesn't return a result isn't labeled with 0
+returns.
+
##### Customizing entire op building
`NativeCodeCall` is not only limited to transforming arguments for building an
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 1168dd12d6f09..0d2ff17870019 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2565,11 +2565,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
// then positional placeholders are also supported; placeholder `$N` in the
// wrapped C++ expression will be replaced by `<argN>`.
+//
+// ## Bind multiple results
+//
+// To bind multi-results and access the N-th result with `$<name>__N`, specify
+// the number of return values in the template. Note that only `Value` type is
+// supported for multiple results binding.
-class NativeCodeCall<string expr> {
+class NativeCodeCall<string expr, int returns = 1> {
string expression = expr;
+ int numReturns = returns;
}
+class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;
+
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4b397c7c02bf6..047abbef27acf 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -100,6 +100,11 @@ class DagLeaf {
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
+ // Returns the number of values will be returned by the native helper
+ // function.
+ // Precondition: isNativeCodeCall()
+ int getNumReturnsOfNativeCode() const;
+
// Returns the string associated with the leaf.
// Precondition: isStringAttr()
std::string getStringAttr() const;
@@ -181,6 +186,11 @@ class DagNode {
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
+ // Returns the number of values will be returned by the native helper
+ // function.
+ // Precondition: isNativeCodeCall()
+ int getNumReturnsOfNativeCode() const;
+
void print(raw_ostream &os) const;
private:
@@ -242,30 +252,32 @@ class SymbolInfoMap {
// DagNode and DagLeaf are accessed by value which means it can't be used as
// identifier here. Use an opaque pointer type instead.
- using DagAndIndex = std::pair<const void *, int>;
+ using DagAndConstant = std::pair<const void *, int>;
// What kind of entity this symbol represents:
// * Attr: op attribute
// * Operand: op operand
// * Result: op result
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
- enum class Kind : uint8_t { Attr, Operand, Result, Value };
+ // * MultipleValues: a pack of values not attached to an op (e.g., from
+ // NativeCodeCall). This kind supports indexing.
+ enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
- // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
- // `Operand` so should be llvm::None for `Result` and `Value` kind.
+ // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
+ // and `Operand` so should be llvm::None for `Result` and `Value` kind.
SymbolInfo(const Operator *op, Kind kind,
- Optional<DagAndIndex> dagAndIndex);
+ Optional<DagAndConstant> dagAndConstant);
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
- return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
+ return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
}
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
}
static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand,
- DagAndIndex(node.getAsOpaquePointer(), index));
+ DagAndConstant(node.getAsOpaquePointer(), index));
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, llvm::None);
@@ -273,6 +285,10 @@ class SymbolInfoMap {
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, llvm::None);
}
+ static SymbolInfo getMultipleValues(int numValues) {
+ return SymbolInfo(nullptr, Kind::MultipleValues,
+ DagAndConstant(nullptr, numValues));
+ }
// Returns the number of static values this symbol corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol
@@ -298,13 +314,21 @@ class SymbolInfoMap {
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
+ // The argument index (for `Attr` and `Operand` only)
+ int getArgIndex() const { return (*dagAndConstant).second; }
+
+ // The number of values in the MultipleValue
+ int getSize() const { return (*dagAndConstant).second; }
+
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
- // The pair of DagNode pointer and argument index (for `Attr` and `Operand`
- // only). Note that operands may be bound to the same symbol, use the
- // DagNode and index to distinguish them. For `Attr`, the Dag part will be
- // nullptr.
- Optional<DagAndIndex> dagAndIndex;
+
+ // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
+ // the size of MultipleValue symbol). Note that operands may be bound to the
+ // same symbol, use the DagNode and index to distinguish them. For `Attr`
+ // and MultipleValue, the Dag part will be nullptr.
+ Optional<DagAndConstant> dagAndConstant;
+
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
@@ -331,10 +355,17 @@ class SymbolInfoMap {
// `symbol` is already bound.
bool bindOpResult(StringRef symbol, const Operator &op);
- // Registers the given `symbol` as bound to a value. Returns false if `symbol`
- // is already bound.
+ // A helper function for dispatching target value binding functions.
+ bool bindValues(StringRef symbol, int numValues = 1);
+
+ // Registers the given `symbol` as bound to the Value(s). Returns false if
+ // `symbol` is already bound.
bool bindValue(StringRef symbol);
+ // Registers the given `symbol` as bound to a MultipleValue. Return false if
+ // `symbol` is already bound.
+ bool bindMultipleValues(StringRef symbol, int numValues);
+
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
// is already bound.
bool bindAttr(StringRef symbol);
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index a9b03519fb540..e7d5a774ad84e 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -83,6 +83,11 @@ llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
}
+int DagLeaf::getNumReturnsOfNativeCode() const {
+ assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+ return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
+}
+
std::string DagLeaf::getStringAttr() const {
assert(isStringAttr() && "the DAG leaf must be string attribute");
return def->getAsUnquotedString();
@@ -119,6 +124,13 @@ llvm::StringRef DagNode::getNativeCodeTemplate() const {
->getValueAsString("expression");
}
+int DagNode::getNumReturnsOfNativeCode() const {
+ assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+ return cast<llvm::DefInit>(node->getOperator())
+ ->getDef()
+ ->getValueAsInt("numReturns");
+}
+
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
@@ -193,8 +205,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
}
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
- Optional<DagAndIndex> dagAndIndex)
- : op(op), kind(kind), dagAndIndex(dagAndIndex) {}
+ Optional<DagAndConstant> dagAndConstant)
+ : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
switch (kind) {
@@ -204,6 +216,8 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
return 1;
case Kind::Result:
return op->getNumResults();
+ case Kind::MultipleValues:
+ return getSize();
}
llvm_unreachable("unknown kind");
}
@@ -217,7 +231,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
switch (kind) {
case Kind::Attr: {
if (op) {
- auto type = op->getArg((*dagAndIndex).second)
+ auto type = op->getArg(getArgIndex())
.get<NamedAttribute *>()
->attr.getStorageType();
return std::string(formatv("{0} {1};\n", type, name));
@@ -235,6 +249,14 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Value: {
return std::string(formatv("::mlir::Value {0};\n", name));
}
+ case Kind::MultipleValues: {
+ // This is for the variable used in the source pattern. Each named value in
+ // source pattern will only be bound to a Value. The others in the result
+ // pattern may be associated with multiple Values as we will use `auto` to
+ // do the type inference.
+ return std::string(formatv(
+ "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
+ }
case Kind::Result: {
// Use the op itself for captured results.
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
@@ -255,8 +277,7 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
}
case Kind::Operand: {
assert(index < 0);
- auto *operand =
- op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
+ auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariableLength()) {
@@ -311,6 +332,21 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
+ case Kind::MultipleValues: {
+ assert(op == nullptr);
+ assert(index < getSize());
+ if (index >= 0) {
+ std::string repl =
+ formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
+ LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+ return repl;
+ }
+ // If it doesn't specify certain element, unpack them all.
+ auto repl =
+ formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
+ LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+ return std::string(repl);
+ }
}
llvm_unreachable("unknown kind");
}
@@ -353,6 +389,20 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
+ case Kind::MultipleValues: {
+ assert(op == nullptr);
+ assert(index < getSize());
+ if (index >= 0) {
+ std::string repl =
+ formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
+ LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+ return repl;
+ }
+ auto repl =
+ formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
+ LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+ return std::string(repl);
+ }
}
llvm_unreachable("unknown kind");
}
@@ -395,11 +445,25 @@ bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
return symbolInfoMap.count(inserted->first) == 1;
}
+bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
+ std::string name = getValuePackName(symbol).str();
+ if (numValues > 1)
+ return bindMultipleValues(name, numValues);
+ return bindValue(name);
+}
+
bool SymbolInfoMap::bindValue(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
}
+bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
+ std::string name = getValuePackName(symbol).str();
+ auto inserted =
+ symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
+ return symbolInfoMap.count(inserted->first) == 1;
+}
+
bool SymbolInfoMap::bindAttr(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
return symbolInfoMap.count(inserted->first) == 1;
@@ -423,11 +487,9 @@ SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
- for (auto it = range.first; it != range.second; ++it) {
- if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
+ for (auto it = range.first; it != range.second; ++it)
+ if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
return it;
- }
- }
return symbolInfoMap.end();
}
@@ -633,7 +695,9 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (!isSrcPattern) {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
<< treeName << '\n');
- verifyBind(infoMap.bindValue(treeName), treeName);
+ verifyBind(
+ infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
+ treeName);
} else {
PrintFatalError(&def,
formatv("binding symbol '{0}' to NativecodeCall in "
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1e163a9c2fad8..c8b656a75e0db 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -857,7 +857,7 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
// Test that NativeCodeCall is not ignored if it is not used to directly
// replace the matched root op.
def : Pattern<(OpNativeCodeCall3 $input),
- [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
+ [(NativeCodeCallVoid<"createOpI($_builder, $_loc, $0)"> $input),
(OpK)]>;
def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
@@ -874,6 +874,19 @@ def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
(OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
+def OpNativeCodeCall6 : TEST_Op<"native_code_call6"> {
+ let arguments = (ins I32:$input1, I32:$input2);
+ let results = (outs I32:$output1, I32:$output2);
+}
+def OpNativeCodeCall7 : TEST_Op<"native_code_call7"> {
+ let arguments = (ins I32:$input);
+ let results = (outs I32);
+}
+def BindMultipleNativeCodeCallResult : NativeCodeCall<"bindMultipleNativeCodeCallResult($0, $1)", 2>;
+def : Pattern<(OpNativeCodeCall6 $arg1, $arg2),
+ [(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)),
+ (OpNativeCodeCall7 $native__1)]>;
+
// Test AllAttrConstraintsOf.
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
let arguments = (ins I64ArrayAttr:$attr);
@@ -1033,7 +1046,7 @@ def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
// Test that we can bind to an op without results and reference it later.
def : Pat<(OpSymbolBindingNoResult:$op $operand),
- (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;
+ (NativeCodeCallVoid<"handleNoResultOp($_builder, $0)"> $op)>;
//===----------------------------------------------------------------------===//
// Test Patterns (Attributes)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 78ad6fefea55e..a6b0d970792f7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -44,6 +44,11 @@ static bool getFirstI32Result(Operation *op, Value &value) {
static Value bindNativeCodeCallResult(Value value) { return value; }
+static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
+ Value input2) {
+ return SmallVector<Value, 2>({input2, input1});
+}
+
// Test that natives calls are only called once during rewrites.
// OpM_Test will return Pi, increased by 1 for each subsequent calls.
// This let us check the number of times OpM_Test was called by inspecting
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 69140dfb3fd47..8af0ef4a65224 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -102,6 +102,16 @@ func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
return %1 : i32
}
+// CHECK-LABEL: verifyMultipleNativeCodeCallBinding
+func @verifyMultipleNativeCodeCallBinding(%arg0 : i32) -> (i32) {
+ %0 = "test.op_k"() : () -> (i32)
+ %1 = "test.op_k"() : () -> (i32)
+ // CHECK: %[[A:.*]] = "test.native_code_call7"(%1) : (i32) -> i32
+ // CHECK: %[[A:.*]] = "test.native_code_call7"(%0) : (i32) -> i32
+ %2, %3 = "test.native_code_call6"(%0, %1) : (i32, i32) -> (i32, i32)
+ return %2 : i32
+}
+
// CHECK-LABEL: verifyAllAttrConstraintOf
func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
// CHECK: "test.all_attr_constraint_of2"
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9d3e4a93b53d6..5913101b86b27 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -754,7 +754,8 @@ void PatternEmitter::emitRewriteLogic() {
// NativeCodeCall will only be materialized to `os` if it is used. Here
// we are handling auxiliary patterns so we want the side effect even if
// NativeCodeCall is not replacing matched root op's results.
- if (resultTree.isNativeCodeCall())
+ if (resultTree.isNativeCodeCall() &&
+ resultTree.getNumReturnsOfNativeCode() == 0)
os << val << ";\n";
}
@@ -804,11 +805,8 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
"location directive can only be used with op creation");
}
- if (resultTree.isNativeCodeCall()) {
- auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
- symbolInfoMap.bindValue(symbol);
- return symbol;
- }
+ if (resultTree.isNativeCodeCall())
+ return handleReplaceWithNativeCodeCall(resultTree, depth);
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree).str();
@@ -948,9 +946,39 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
}
std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
- if (!tree.getSymbol().empty()) {
- os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
- symbol = tree.getSymbol().str();
+
+ // In general, NativeCodeCall without naming binding don't need this. To
+ // ensure void helper function has been correctly labeled, i.e., use
+ // NativeCodeCallVoid, we cache the result to a local variable so that we will
+ // get a compilation error in the auto-generated file.
+ // Example.
+ // // In the td file
+ // Pat<(...), (NativeCodeCall<Foo> ...)>
+ //
+ // ---
+ //
+ // // In the auto-generated .cpp
+ // ...
+ // // Causes compilation error if Foo() returns void.
+ // auto nativeVar = Foo();
+ // ...
+ if (tree.getNumReturnsOfNativeCode() != 0) {
+ // Determine the local variable name for return value.
+ std::string varName =
+ SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
+ if (varName.empty()) {
+ varName = formatv("nativeVar_{0}", nextValueId++);
+ // Register the local variable for later uses.
+ symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
+ }
+
+ // Catch the return value of helper function.
+ os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
+
+ if (!tree.getSymbol().empty())
+ symbol = tree.getSymbol().str();
+ else
+ symbol = varName;
}
return symbol;
@@ -967,8 +995,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
// Otherwise this is an unbound op; we will use all its results.
return pattern.getDialectOp(node).getNumResults();
}
- // TODO: This considers all NativeCodeCall as returning one
- // value. Enhance if multi-value ones are needed.
+
+ if (node.isNativeCodeCall())
+ return node.getNumReturnsOfNativeCode();
+
return 1;
}
@@ -1191,8 +1221,7 @@ void PatternEmitter::supplyValuesForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv("/*{0}=*/{1}", opArgName,
- handleReplaceWithNativeCodeCall(subTree, depth));
+ os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
@@ -1233,8 +1262,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv(addAttrCmd, opArgName,
- handleReplaceWithNativeCodeCall(subTree, depth + 1));
+ os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
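
The RewriterGen.cpp change above stores a helper's return value in a local variable and then indexes it for each `$<name>__N` use. The sketch below imitates the rough shape of that generated code under simplifying assumptions: `Value` is an `int` stand-in for `mlir::Value`, `std::vector` stands in for `SmallVector`, and `bindTwoResults` and `rewrite` are hypothetical names chosen for this note. It is not the output of mlir-tblgen.

```cpp
#include <cassert>
#include <vector>

// Stand-in for mlir::Value so the sketch compiles without MLIR headers.
using Value = int;

// Same shape as the test helper added in this commit: returns two values,
// swapped, so that each result can be told apart.
static std::vector<Value> bindTwoResults(Value input1, Value input2) {
  return {input2, input1};
}

// Roughly what a generated rewrite body does for
//   (Helper:$native__0 $arg1, $arg2) followed by a later use of $native__1.
static std::vector<Value> rewrite(Value arg1, Value arg2) {
  // The helper's result is captured once. This line would fail to compile if
  // the helper returned void, which is how a mislabeled helper is detected.
  auto native = bindTwoResults(arg1, arg2);
  (void)native;

  Value first = native[0];  // corresponds to $native__0
  Value second = native[1]; // corresponds to $native__1
  return {first, second};
}
```

Because the helper swaps its inputs, `rewrite(10, 20)` yields `20` for the first result and `10` for the second, matching the operand order checked in pattern.mlir.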