[Mlir-commits] [mlir] d7314b3 - [mlir-tblgen] Support binding multi-results of NativeCodeCall

Chia-hung Duan llvmlistbot at llvm.org
Tue Jul 20 20:35:55 PDT 2021


Author: Chia-hung Duan
Date: 2021-07-21T11:23:22+08:00
New Revision: d7314b3c094e96fbca7b195eab5fa521bda5fe22

URL: https://github.com/llvm/llvm-project/commit/d7314b3c094e96fbca7b195eab5fa521bda5fe22
DIFF: https://github.com/llvm/llvm-project/commit/d7314b3c094e96fbca7b195eab5fa521bda5fe22.diff

LOG: [mlir-tblgen] Support binding multi-results of NativeCodeCall

The result of a NativeCodeCall can already be bound to a symbol in the
result pattern. To let TableGen understand the shape of the helper
function's return value, and to allow binding multiple results, the number
of return values must now be specified in the NativeCodeCall template. A
NativeCodeCallVoid class is added for the void case.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D102160
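
As a rough illustration, the sketch below approximates the C++ that `mlir-tblgen` emits for the two-result `BindMultipleNativeCodeCallResult` pattern added to the test dialect in this commit. It is a standalone approximation, not generated output: `Value` is a hypothetical stand-in for `mlir::Value`, `std::vector` replaces `SmallVector<Value, 2>`, and `createOp7` and `rewrite` are hypothetical placeholders for building `test.native_code_call7` and for the generated rewrite body. Following the RewriterGen change, the helper's result is captured once with `auto`, and `$native__0` / `$native__1` are substituted as `native[0]` / `native[1]`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Value: only carries an SSA name for tracing.
struct Value {
  std::string name;
};

// Mirrors the helper in TestPatterns.cpp: returns the two inputs swapped.
// std::vector stands in for SmallVector<Value, 2>.
static std::vector<Value> bindMultipleNativeCodeCallResult(Value input1,
                                                           Value input2) {
  return {input2, input1};
}

// Hypothetical placeholder for creating one test.native_code_call7 op;
// it only records which operand the new op receives.
static void createOp7(std::vector<std::string> &created, Value operand) {
  created.push_back(operand.name);
}

// Approximates the rewrite generated for
//   [(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)),
//    (OpNativeCodeCall7 $native__1)]
// and returns the operands of the ops it "creates", in creation order.
static std::vector<std::string> rewrite(Value arg1, Value arg2) {
  std::vector<std::string> created;
  // The helper's result is captured once with `auto`. If the helper returned
  // void, this line would not compile, which is the check that forces void
  // helpers to be declared with NativeCodeCallVoid.
  auto native = bindMultipleNativeCodeCallResult(arg1, arg2);
  (void)native;
  createOp7(created, native[0]); // $native__0
  createOp7(created, native[1]); // $native__1
  return created;
}
```

Calling `rewrite(Value{"%0"}, Value{"%1"})` yields `{"%1", "%0"}`, the same operand order that the `CHECK` lines in `pattern.mlir` expect for `test.native_code_call7`.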

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 5815035ca77e8..c3518972bb0ef 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -377,9 +377,6 @@ template. The string can be an arbitrary C++ expression that evaluates into some
 C++ object expected at the `NativeCodeCall` site (here it would be expecting an
 array attribute). Typically the string should be a function call.
 
-Note that currently `NativeCodeCall` must return no more than one value or
-attribute. This might change in the future.
-
 ##### `NativeCodeCall` placeholders
 
 In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
@@ -428,6 +425,30 @@ parameters at the `NativeCodeCall` use site. For example, if we define
 `SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
 $in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`.
 
+##### `NativeCodeCall` binding multi-results
+
+To bind multi-results and access the N-th result with `$<name>__N`, specify the
+number of return values in the template. Note that only `Value` type is
+supported for multiple results binding. For example,
+
+```tablegen
+
+def PackAttrs : NativeCodeCall<"packAttrs($0, $1)", 2>;
+def : Pattern<(TwoResultOp $attr1, $attr2),
+              [(OneResultOp (PackAttrs:$res__0 $attr1, $attr2)),
+               (OneResultOp $res__1)]>;
+
+```
+
+Use `NativeCodeCallVoid` when the helper function has no return value.
+
+The number of return values specified in `NativeCodeCall` must be correct: it
+is used to verify that the number of result values is consistent.
+Additionally, `mlir-tblgen` captures the return value of a `NativeCodeCall` in
+the generated code, so a helper that returns nothing but is not declared with 0
+returns (i.e., via `NativeCodeCallVoid`) triggers a compilation error in the
+generated file.
+
 ##### Customizing entire op building
 
 `NativeCodeCall` is not only limited to transforming arguments for building an

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 1168dd12d6f09..0d2ff17870019 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2565,11 +2565,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
 // If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
 // then positional placeholders are also supported; placeholder `$N` in the
 // wrapped C++ expression will be replaced by `<argN>`.
+//
+// ## Bind multiple results
+//
+// To bind multi-results and access the N-th result with `$<name>__N`, specify
+// the number of return values in the template. Note that only `Value` type is
+// supported for multiple results binding.
 
-class NativeCodeCall<string expr> {
+class NativeCodeCall<string expr, int returns = 1> {
   string expression = expr;
+  int numReturns = returns;
 }
 
+class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;
+
 def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4b397c7c02bf6..047abbef27acf 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -100,6 +100,11 @@ class DagLeaf {
   // Precondition: isNativeCodeCall()
   StringRef getNativeCodeTemplate() const;
 
+  // Returns the number of values that will be returned by the native helper
+  // function.
+  // Precondition: isNativeCodeCall()
+  int getNumReturnsOfNativeCode() const;
+
   // Returns the string associated with the leaf.
   // Precondition: isStringAttr()
   std::string getStringAttr() const;
@@ -181,6 +186,11 @@ class DagNode {
   // Precondition: isNativeCodeCall()
   StringRef getNativeCodeTemplate() const;
 
+  // Returns the number of values that will be returned by the native helper
+  // function.
+  // Precondition: isNativeCodeCall()
+  int getNumReturnsOfNativeCode() const;
+
   void print(raw_ostream &os) const;
 
 private:
@@ -242,30 +252,32 @@ class SymbolInfoMap {
 
     // DagNode and DagLeaf are accessed by value which means it can't be used as
     // identifier here. Use an opaque pointer type instead.
-    using DagAndIndex = std::pair<const void *, int>;
+    using DagAndConstant = std::pair<const void *, int>;
 
     // What kind of entity this symbol represents:
     // * Attr: op attribute
     // * Operand: op operand
     // * Result: op result
     // * Value: a value not attached to an op (e.g., from NativeCodeCall)
-    enum class Kind : uint8_t { Attr, Operand, Result, Value };
+    // * MultipleValues: a pack of values not attached to an op (e.g., from
+    //   NativeCodeCall). This kind supports indexing.
+    enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
 
-    // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
-    // `Operand` so should be llvm::None for `Result` and `Value` kind.
+    // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
+    // and `Operand` so should be llvm::None for `Result` and `Value` kind.
     SymbolInfo(const Operator *op, Kind kind,
-               Optional<DagAndIndex> dagAndIndex);
+               Optional<DagAndConstant> dagAndConstant);
 
     // Static methods for creating SymbolInfo.
     static SymbolInfo getAttr(const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
+      return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
     }
     static SymbolInfo getAttr() {
       return SymbolInfo(nullptr, Kind::Attr, llvm::None);
     }
     static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
       return SymbolInfo(op, Kind::Operand,
-                        DagAndIndex(node.getAsOpaquePointer(), index));
+                        DagAndConstant(node.getAsOpaquePointer(), index));
     }
     static SymbolInfo getResult(const Operator *op) {
       return SymbolInfo(op, Kind::Result, llvm::None);
@@ -273,6 +285,10 @@ class SymbolInfoMap {
     static SymbolInfo getValue() {
       return SymbolInfo(nullptr, Kind::Value, llvm::None);
     }
+    static SymbolInfo getMultipleValues(int numValues) {
+      return SymbolInfo(nullptr, Kind::MultipleValues,
+                        DagAndConstant(nullptr, numValues));
+    }
 
     // Returns the number of static values this symbol corresponds to.
     // A static value is an operand/result declared in ODS. Normally a symbol
@@ -298,13 +314,21 @@ class SymbolInfoMap {
     std::string getAllRangeUse(StringRef name, int index, const char *fmt,
                                const char *separator) const;
 
+    // The argument index (for `Attr` and `Operand` only)
+    int getArgIndex() const { return (*dagAndConstant).second; }
+
+    // The number of values in the MultipleValue
+    int getSize() const { return (*dagAndConstant).second; }
+
     const Operator *op; // The op where the bound entity belongs
     Kind kind;          // The kind of the bound entity
-    // The pair of DagNode pointer and argument index (for `Attr` and `Operand`
-    // only). Note that operands may be bound to the same symbol, use the
-    // DagNode and index to distinguish them. For `Attr`, the Dag part will be
-    // nullptr.
-    Optional<DagAndIndex> dagAndIndex;
+
+    // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
+    // the size of MultipleValue symbol). Note that operands may be bound to the
+    // same symbol, use the DagNode and index to distinguish them. For `Attr`
+    // and MultipleValue, the Dag part will be nullptr.
+    Optional<DagAndConstant> dagAndConstant;
+
     // Alternative name for the symbol. It is used in case the name
     // is not unique. Applicable for `Operand` only.
     Optional<std::string> alternativeName;
@@ -331,10 +355,17 @@ class SymbolInfoMap {
   // `symbol` is already bound.
   bool bindOpResult(StringRef symbol, const Operator &op);
 
-  // Registers the given `symbol` as bound to a value. Returns false if `symbol`
-  // is already bound.
+  // A helper function for dispatching target value binding functions.
+  bool bindValues(StringRef symbol, int numValues = 1);
+
+  // Registers the given `symbol` as bound to the Value(s). Returns false if
+  // `symbol` is already bound.
   bool bindValue(StringRef symbol);
 
+  // Registers the given `symbol` as bound to a MultipleValue. Return false if
+  // `symbol` is already bound.
+  bool bindMultipleValues(StringRef symbol, int numValues);
+
   // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
   // is already bound.
   bool bindAttr(StringRef symbol);

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index a9b03519fb540..e7d5a774ad84e 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -83,6 +83,11 @@ llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
 }
 
+int DagLeaf::getNumReturnsOfNativeCode() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
+}
+
 std::string DagLeaf::getStringAttr() const {
   assert(isStringAttr() && "the DAG leaf must be string attribute");
   return def->getAsUnquotedString();
@@ -119,6 +124,13 @@ llvm::StringRef DagNode::getNativeCodeTemplate() const {
       ->getValueAsString("expression");
 }
 
+int DagNode::getNumReturnsOfNativeCode() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(node->getOperator())
+      ->getDef()
+      ->getValueAsInt("numReturns");
+}
+
 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
 
 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
@@ -193,8 +205,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
 }
 
 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
-                                      Optional<DagAndIndex> dagAndIndex)
-    : op(op), kind(kind), dagAndIndex(dagAndIndex) {}
+                                      Optional<DagAndConstant> dagAndConstant)
+    : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
 
 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
   switch (kind) {
@@ -204,6 +216,8 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
     return 1;
   case Kind::Result:
     return op->getNumResults();
+  case Kind::MultipleValues:
+    return getSize();
   }
   llvm_unreachable("unknown kind");
 }
@@ -217,7 +231,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   switch (kind) {
   case Kind::Attr: {
     if (op) {
-      auto type = op->getArg((*dagAndIndex).second)
+      auto type = op->getArg(getArgIndex())
                       .get<NamedAttribute *>()
                       ->attr.getStorageType();
       return std::string(formatv("{0} {1};\n", type, name));
@@ -235,6 +249,14 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   case Kind::Value: {
     return std::string(formatv("::mlir::Value {0};\n", name));
   }
+  case Kind::MultipleValues: {
+    // This is for the variable used in the source pattern. Each named value in
+    // source pattern will only be bound to a Value. The others in the result
+    // pattern may be associated with multiple Values as we will use `auto` to
+    // do the type inference.
+    return std::string(formatv(
+        "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
+  }
   case Kind::Result: {
     // Use the op itself for captured results.
     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
@@ -255,8 +277,7 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   }
   case Kind::Operand: {
     assert(index < 0);
-    auto *operand =
-        op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
+    auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
     // If this operand is variadic, then return a range. Otherwise, return the
     // value itself.
     if (operand->isVariableLength()) {
@@ -311,6 +332,21 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
     return std::string(repl);
   }
+  case Kind::MultipleValues: {
+    assert(op == nullptr);
+    assert(index < getSize());
+    if (index >= 0) {
+      std::string repl =
+          formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
+      LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+      return repl;
+    }
+    // If it doesn't specify certain element, unpack them all.
+    auto repl =
+        formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
+    LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+    return std::string(repl);
+  }
   }
   llvm_unreachable("unknown kind");
 }
@@ -353,6 +389,20 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
     return std::string(repl);
   }
+  case Kind::MultipleValues: {
+    assert(op == nullptr);
+    assert(index < getSize());
+    if (index >= 0) {
+      std::string repl =
+          formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
+      LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+      return repl;
+    }
+    auto repl =
+        formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
+    LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
+    return std::string(repl);
+  }
   }
   llvm_unreachable("unknown kind");
 }
@@ -395,11 +445,25 @@ bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
   return symbolInfoMap.count(inserted->first) == 1;
 }
 
+bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
+  std::string name = getValuePackName(symbol).str();
+  if (numValues > 1)
+    return bindMultipleValues(name, numValues);
+  return bindValue(name);
+}
+
 bool SymbolInfoMap::bindValue(StringRef symbol) {
   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
   return symbolInfoMap.count(inserted->first) == 1;
 }
 
+bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
+  std::string name = getValuePackName(symbol).str();
+  auto inserted =
+      symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
+  return symbolInfoMap.count(inserted->first) == 1;
+}
+
 bool SymbolInfoMap::bindAttr(StringRef symbol) {
   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
   return symbolInfoMap.count(inserted->first) == 1;
@@ -423,11 +487,9 @@ SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
 
   const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
 
-  for (auto it = range.first; it != range.second; ++it) {
-    if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
+  for (auto it = range.first; it != range.second; ++it)
+    if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
       return it;
-    }
-  }
 
   return symbolInfoMap.end();
 }
@@ -633,7 +695,9 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
       if (!isSrcPattern) {
         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
                                 << treeName << '\n');
-        verifyBind(infoMap.bindValue(treeName), treeName);
+        verifyBind(
+            infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
+            treeName);
       } else {
         PrintFatalError(&def,
                         formatv("binding symbol '{0}' to NativecodeCall in "

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1e163a9c2fad8..c8b656a75e0db 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -857,7 +857,7 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
 // Test that NativeCodeCall is not ignored if it is not used to directly
 // replace the matched root op.
 def : Pattern<(OpNativeCodeCall3 $input),
-              [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
+              [(NativeCodeCallVoid<"createOpI($_builder, $_loc, $0)"> $input),
                (OpK)]>;
 
 def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
@@ -874,6 +874,19 @@ def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
 def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
           (OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
 
+def OpNativeCodeCall6 : TEST_Op<"native_code_call6"> {
+  let arguments = (ins I32:$input1, I32:$input2);
+  let results = (outs I32:$output1, I32:$output2);
+}
+def OpNativeCodeCall7 : TEST_Op<"native_code_call7"> {
+  let arguments = (ins I32:$input);
+  let results = (outs I32);
+}
+def BindMultipleNativeCodeCallResult : NativeCodeCall<"bindMultipleNativeCodeCallResult($0, $1)", 2>;
+def : Pattern<(OpNativeCodeCall6 $arg1, $arg2),
+              [(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)),
+               (OpNativeCodeCall7 $native__1)]>;
+
 // Test AllAttrConstraintsOf.
 def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
   let arguments = (ins I64ArrayAttr:$attr);
@@ -1033,7 +1046,7 @@ def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
 
 // Test that we can bind to an op without results and reference it later.
 def : Pat<(OpSymbolBindingNoResult:$op $operand),
-          (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;
+          (NativeCodeCallVoid<"handleNoResultOp($_builder, $0)"> $op)>;
 
 //===----------------------------------------------------------------------===//
 // Test Patterns (Attributes)

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 78ad6fefea55e..a6b0d970792f7 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -44,6 +44,11 @@ static bool getFirstI32Result(Operation *op, Value &value) {
 
 static Value bindNativeCodeCallResult(Value value) { return value; }
 
+static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
+                                                              Value input2) {
+  return SmallVector<Value, 2>({input2, input1});
+}
+
 // Test that natives calls are only called once during rewrites.
 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
 // This let us check the number of times OpM_Test was called by inspecting

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 69140dfb3fd47..8af0ef4a65224 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -102,6 +102,16 @@ func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
   return %1 : i32
 }
 
+// CHECK-LABEL: verifyMultipleNativeCodeCallBinding
+func @verifyMultipleNativeCodeCallBinding(%arg0 : i32) -> (i32) {
+  %0 = "test.op_k"() : () -> (i32)
+  %1 = "test.op_k"() : () -> (i32)
+  // CHECK: %[[A:.*]] = "test.native_code_call7"(%1) : (i32) -> i32
+  // CHECK: %[[A:.*]] = "test.native_code_call7"(%0) : (i32) -> i32
+  %2, %3 = "test.native_code_call6"(%0, %1) : (i32, i32) -> (i32, i32)
+  return %2 : i32
+}
+
 // CHECK-LABEL: verifyAllAttrConstraintOf
 func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
   // CHECK: "test.all_attr_constraint_of2"

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9d3e4a93b53d6..5913101b86b27 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -754,7 +754,8 @@ void PatternEmitter::emitRewriteLogic() {
     // NativeCodeCall will only be materialized to `os` if it is used. Here
     // we are handling auxiliary patterns so we want the side effect even if
     // NativeCodeCall is not replacing matched root op's results.
-    if (resultTree.isNativeCodeCall())
+    if (resultTree.isNativeCodeCall() &&
+        resultTree.getNumReturnsOfNativeCode() == 0)
       os << val << ";\n";
   }
 
@@ -804,11 +805,8 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
                     "location directive can only be used with op creation");
   }
 
-  if (resultTree.isNativeCodeCall()) {
-    auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
-    symbolInfoMap.bindValue(symbol);
-    return symbol;
-  }
+  if (resultTree.isNativeCodeCall())
+    return handleReplaceWithNativeCodeCall(resultTree, depth);
 
   if (resultTree.isReplaceWithValue())
     return handleReplaceWithValue(resultTree).str();
@@ -948,9 +946,39 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
   }
 
   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
-  if (!tree.getSymbol().empty()) {
-    os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
-    symbol = tree.getSymbol().str();
+
+  // In general, a NativeCodeCall without a named binding doesn't need this.
+  // To ensure a void helper function has been correctly labeled, i.e., with
+  // NativeCodeCallVoid, we cache the result to a local variable so that we
+  // get a compilation error in the auto-generated file.
+  // Example:
+  //   // In the td file
+  //   Pat<(...), (NativeCodeCall<Foo> ...)>
+  //
+  //   ---
+  //
+  //   // In the auto-generated .cpp
+  //   ...
+  //   // Causes compilation error if Foo() returns void.
+  //   auto nativeVar = Foo();
+  //   ...
+  if (tree.getNumReturnsOfNativeCode() != 0) {
+    // Determine the local variable name for return value.
+    std::string varName =
+        SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
+    if (varName.empty()) {
+      varName = formatv("nativeVar_{0}", nextValueId++);
+      // Register the local variable for later uses.
+      symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
+    }
+
+    // Catch the return value of helper function.
+    os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
+
+    if (!tree.getSymbol().empty())
+      symbol = tree.getSymbol().str();
+    else
+      symbol = varName;
   }
 
   return symbol;
@@ -967,8 +995,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
     // Otherwise this is an unbound op; we will use all its results.
     return pattern.getDialectOp(node).getNumResults();
   }
-  // TODO: This considers all NativeCodeCall as returning one
-  // value. Enhance if multi-value ones are needed.
+
+  if (node.isNativeCodeCall())
+    return node.getNumReturnsOfNativeCode();
+
   return 1;
 }
 
@@ -1191,8 +1221,7 @@ void PatternEmitter::supplyValuesForOpArgs(
       if (!subTree.isNativeCodeCall())
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
-      os << formatv("/*{0}=*/{1}", opArgName,
-                    handleReplaceWithNativeCodeCall(subTree, depth));
+      os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
     } else {
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
@@ -1233,8 +1262,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
-        os << formatv(addAttrCmd, opArgName,
-                      handleReplaceWithNativeCodeCall(subTree, depth + 1));
+        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
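
The `$<name>__N` syntax relies on splitting a bound symbol at its last `__` into a pack name and an element index. The sketch below approximates that lookup and the two expansions that `getValueAndRangeUse` produces for a `MultipleValues` symbol: `name[N]` for an indexed use and `name.begin(), name.end()` when no index is given. `splitValuePackName` and `expandUse` are hypothetical names; this is plain `std::string` code written under those assumptions, not the actual `SymbolInfoMap` implementation, which operates on `StringRef`.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical approximation of splitting "res__1" into the pack name "res"
// and element index 1. If there is no "__" or the suffix is not an integer,
// the whole symbol is returned and *index is left at -1.
static std::string splitValuePackName(const std::string &symbol, int *index) {
  *index = -1;
  std::string::size_type pos = symbol.rfind("__");
  if (pos == std::string::npos || pos + 2 == symbol.size())
    return symbol;
  std::string digits = symbol.substr(pos + 2);
  for (char c : digits)
    if (!std::isdigit(static_cast<unsigned char>(c)))
      return symbol;
  *index = std::stoi(digits);
  return symbol.substr(0, pos);
}

// Approximates the substitution for a symbol bound as MultipleValues:
// an explicit index selects one element, otherwise the whole pack is
// expanded as an iterator range.
static std::string expandUse(const std::string &symbol) {
  int index;
  std::string name = splitValuePackName(symbol, &index);
  if (index >= 0)
    return name + "[" + std::to_string(index) + "]";
  return name + ".begin(), " + name + ".end()";
}
```

For example, `expandUse("res__1")` gives `res[1]`, while `expandUse("res")` gives `res.begin(), res.end()`. A symbol such as `my__name`, whose suffix is not an integer, is kept whole.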