[Mlir-commits] [mlir] cb8c30d - [DRR] Explicit Return Types in Rewrites

Jacques Pienaar llvmlistbot at llvm.org
Wed Sep 15 14:25:51 PDT 2021


Author: Mogball
Date: 2021-09-15T14:25:29-07:00
New Revision: cb8c30d35dc9eedca4b8073e96f06e9ce8f12192

URL: https://github.com/llvm/llvm-project/commit/cb8c30d35dc9eedca4b8073e96f06e9ce8f12192
DIFF: https://github.com/llvm/llvm-project/commit/cb8c30d35dc9eedca4b8073e96f06e9ce8f12192.diff

LOG: [DRR] Explicit Return Types in Rewrites

Adds a new rewrite directive, `returnType`, that can be added at the end of an
op's argument list to explicitly specify its return types.

```
(OpX $v0, $v1, (returnType "$_builder.getI32Type()"))
```

Pass a bound value to copy its type, or pass a native code call to create new
types dynamically.

```
(OpX $v0, $v1, (returnType $v0, (NativeCodeCall<"..."> $v1)))
```

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D109472

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/test/mlir-tblgen/rewriter-errors.td
    mlir/test/mlir-tblgen/rewriter-indexing.td
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index cb9742428101..ac99ac4d0897 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2637,15 +2637,49 @@ def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult
 // Rewrite directives
 //===----------------------------------------------------------------------===//
 
-// Directive used in result pattern to specify the location of the generated
-// op. This directive must be used as the last argument to the op creation
-// DAG construct. The arguments to location must be previously captured symbol.
-def location;
-
 // Directive used in result pattern to indicate that no new op are generated,
 // so to replace the matched DAG with an existing SSA value.
 def replaceWithValue;
 
+// Directive used in result patterns to specify the location of the generated
+// op. This directive must be used as a trailing argument to op creation or
+// native code calls.
+//
+// Usage:
+// * Create a named location: `(location "myLocation")`
+// * Copy the location of a captured symbol: `(location $arg)`
+// * Create a fused location: `(location "metadata", $arg0, $arg1)`
+
+def location;
+
+// Directive used in result patterns to specify return types for a created op.
+// This allows ops to be created without relying on type inference with
+// `OpTraits` or an op builder with deduction.
+// This directive must be used as a trailing argument to op creation. It cannot
+// be used on an op whose results replace the source pattern's root op.
+//
+// Specify one return type with a string literal:
+//
+// ```
+// (AnOp $val, (returnType "$_builder.getI32Type()"))
+// ```
+//
+// Pass a captured value to copy its return type:
+//
+// ```
+// (AnOp $val, (returnType $val))
+// ```
+//
+// Pass a native code call inside a DAG to create a new type with arguments.
+//
+// ```
+// (AnOp $val,
+//       (returnType (NativeCodeCall<"$_builder.getTupleType({$0})"> $val)));
+// ```
+//
+// Specify multiple return types by combining any of the above forms.
+
+def returnType;
 
 //===----------------------------------------------------------------------===//
 // Attribute and Type generation

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 047abbef27ac..a3786cd8e0b8 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -176,6 +176,9 @@ class DagNode {
   // Returns whether this DAG represents the location of an op creation.
   bool isLocationDirective() const;
 
+  // Returns whether this DAG is a return type specifier.
+  bool isReturnTypeDirective() const;
+
   // Returns true if this DAG node is wrapping native code call.
   bool isNativeCodeCall() const;
 

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index e7d5a774ad84..ce225ed93076 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -54,9 +54,7 @@ bool DagLeaf::isEnumAttrCase() const {
   return isSubClassOf("EnumAttrCaseInfo");
 }
 
-bool DagLeaf::isStringAttr() const {
-  return isa<llvm::StringInit>(def);
-}
+bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
 Constraint DagLeaf::getAsConstraint() const {
   assert((isOperandMatcher() || isAttrMatcher()) &&
@@ -114,7 +112,8 @@ bool DagNode::isNativeCodeCall() const {
 }
 
 bool DagNode::isOperation() const {
-  return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
+  return !isNativeCodeCall() && !isReplaceWithValue() &&
+         !isLocationDirective() && !isReturnTypeDirective();
 }
 
 llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -180,6 +179,11 @@ bool DagNode::isLocationDirective() const {
   return dagOpDef->getName() == "location";
 }
 
+bool DagNode::isReturnTypeDirective() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "returnType";
+}
+
 void DagNode::print(raw_ostream &os) const {
   if (node)
     node->print(os);
@@ -753,14 +757,18 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
     auto &op = getDialectOp(tree);
     auto numOpArgs = op.getNumArgs();
 
-    // The pattern might have the last argument specifying the location.
-    bool hasLocDirective = false;
-    if (numTreeArgs != 0) {
-      if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
-        hasLocDirective = lastArg.isLocationDirective();
+    // The pattern might have trailing directives.
+    int numDirectives = 0;
+    for (int i = numTreeArgs - 1; i >= 0; --i) {
+      if (auto dagArg = tree.getArgAsNestedDag(i)) {
+        if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
+          ++numDirectives;
+        else
+          break;
+      }
     }
 
-    if (numOpArgs != numTreeArgs - hasLocDirective) {
+    if (numOpArgs != numTreeArgs - numDirectives) {
       auto err = formatv("op '{0}' argument number mismatch: "
                          "{1} in pattern vs. {2} in definition",
                          op.getOperationName(), numTreeArgs, numOpArgs);

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fcc0f2861ccf..e17a76b67a51 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1341,6 +1341,64 @@ def : Pat<(TestLocationSrcOp:$res1
               (location "named")),
             (location "fused", $res2, $res3))>;
 
+//===----------------------------------------------------------------------===//
+// Test Patterns (Type Builders)
+
+def SourceOp : TEST_Op<"source_op"> {
+  let arguments = (ins AnyInteger:$arg, AnyI32Attr:$tag);
+  let results = (outs AnyInteger);
+}
+
+// An op without return type deduction.
+def OpX : TEST_Op<"op_x"> {
+  let arguments = (ins AnyInteger:$input);
+  let results = (outs AnyInteger);
+}
+
+// Test that ops without built-in type deduction can be created in the
+// replacement DAG with an explicitly specified type.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "11">:$attr),
+          (OpX (OpX $val, (returnType "$_builder.getI32Type()")))>;
+// Test that a NativeCodeCall type builder can accept arguments.
+def SameTypeAs : NativeCodeCall<"$0.getType()">;
+
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "22">:$attr),
+          (OpX (OpX $val, (returnType (SameTypeAs $val))))>;
+
+// Test multiple return types.
+def MakeI64Type : NativeCodeCall<"$_builder.getI64Type()">;
+def MakeI32Type : NativeCodeCall<"$_builder.getI32Type()">;
+
+def OneToTwo : TEST_Op<"one_to_two"> {
+  let arguments = (ins AnyInteger);
+  let results = (outs AnyInteger, AnyInteger);
+}
+
+def TwoToOne : TEST_Op<"two_to_one"> {
+  let arguments = (ins AnyInteger, AnyInteger);
+  let results = (outs AnyInteger);
+}
+
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "33">:$attr),
+          (TwoToOne (OpX (OneToTwo:$res__0 $val, (returnType (MakeI64Type), (MakeI32Type))), (returnType (MakeI32Type))),
+                    (OpX $res__1, (returnType (MakeI64Type))))>;
+
+// Test copying the type of a bound value.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "44">:$attr),
+          (OpX (OpX $val, (returnType $val)))>;
+
+// Test creating multiple return types with different methods.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "55">:$attr),
+          (TwoToOne (OneToTwo:$res__0 $val, (returnType $val, "$_builder.getI64Type()")), $res__1)>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Trailing Directives)
+
+// Test that we can specify both `location` and `returnType` directives.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "66">:$attr),
+          (TwoToOne (OpX $val, (returnType $val), (location "loc1")),
+                    (OpX $val, (location "loc2"), (returnType $val)))>;
+
 //===----------------------------------------------------------------------===//
 // Test Legalization
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 8af0ef4a6522..4a05df4782c7 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -530,3 +530,56 @@ func @redundantTest(%arg0: i32) -> i32 {
   // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
   return %0 : i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test that ops without type deduction can be created with type builders.
+//===----------------------------------------------------------------------===//
+
+func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
+  %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
+  // CHECK: "test.op_x"(%arg0) : (i64) -> i32
+  // CHECK: "test.op_x"(%0) : (i32) -> i8
+  return %0 : i8
+}
+
+func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
+  %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
+  // CHECK: "test.op_x"(%arg0) : (i1) -> i1
+  // CHECK: "test.op_x"(%0) : (i1) -> i8
+  return %0 : i8
+}
+
+func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
+  %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
+  // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
+  // CHECK: "test.op_x"(%0#0) : (i64) -> i32
+  // CHECK: "test.op_x"(%0#1) : (i32) -> i64
+  // CHECK: "test.two_to_one"(%1, %2) : (i32, i64) -> i1
+  return %0 : i1
+}
+
+func @copyValueType(%arg0 : i8) -> i32 {
+  %0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
+  // CHECK: "test.op_x"(%arg0) : (i8) -> i8
+  // CHECK: "test.op_x"(%0) : (i8) -> i32
+  return %0 : i32
+}
+
+func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
+  %0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
+  // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
+  // CHECK: "test.two_to_one"(%0#0, %0#1) : (i1, i64) -> i64
+  return %0 : i64
+}
+
+//===----------------------------------------------------------------------===//
+// Test that multiple trailing directives can be mixed in patterns.
+//===----------------------------------------------------------------------===//
+
+func @returnTypeAndLocation(%arg0 : i32) -> i1 {
+  %0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
+  // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
+  // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc2")
+  // CHECK: "test.two_to_one"(%0, %1) : (i32, i32) -> i1
+  return %0 : i1
+}

diff  --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td
index 60e4710688e7..3cebe2dfd34f 100644
--- a/mlir/test/mlir-tblgen/rewriter-errors.td
+++ b/mlir/test/mlir-tblgen/rewriter-errors.td
@@ -1,6 +1,8 @@
 // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
 // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
 // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR5 %s 2>&1 | FileCheck --check-prefix=ERROR5 %s
 
 include "mlir/IR/OpBase.td"
 
@@ -35,3 +37,15 @@ def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">;
 def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
           (OpB $val, $arg)>;
 #endif
+
+#ifdef ERROR4
// Check that passing an op as a DAG node inside `returnType` fails.
+// ERROR4: [[@LINE+1]]:1: error: nested DAG in `returnType` must be a native code
+def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA (OpA $val, $val, (returnType (OpA $val, $val))), $val)>;
+#endif
+
+#ifdef ERROR5
+// Check that trying to specify explicit types at the root node fails.
+// ERROR5: [[@LINE+1]]:1: error: Cannot specify explicit return types in an op
+def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA $val, $val, (returnType "someType()"))>;
+#endif

diff  --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index f4f055e1c0c4..e31d78c2481e 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -90,3 +90,13 @@ def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
 // CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
 def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
                 (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// Check pattern with return type builders.
+def SameTypeAs : NativeCodeCall<"$0.getType()">;
+// CHECK: struct test6 : public ::mlir::RewritePattern {
+// CHECK: tblgen_types.push_back((*v2.begin()).getType())
+// CHECK: tblgen_types.push_back(rewriter.getI32Type())
+// CHECK: nativeVar_1 = doSomething((*v3.begin()))
+// CHECK: tblgen_types.push_back(nativeVar_1)
+def test6 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+                (AOp (AOp $v1, (returnType $v2, "$_builder.getI32Type()", (NativeCodeCall<"doSomething($0)"> $v3))))>;

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 0ef5a5789427..37fe800c2d4c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -127,12 +127,31 @@ class PatternEmitter {
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
+  // Trailing directives are used at the end of DAG node argument lists to
+  // specify additional behaviour for op matchers and creators, etc.
+  struct TrailingDirectives {
+    // DAG node containing the `location` directive. Null if there is none.
+    DagNode location;
+
+    // DAG node containing the `returnType` directive. Null if there is none.
+    DagNode returnType;
+
+    // Number of trailing directives found.
+    int numDirectives;
+  };
+
+  // Collect any trailing directives.
+  TrailingDirectives getTrailingDirectives(DagNode tree);
+
   // Returns the location value to use.
-  std::pair<bool, std::string> getLocation(DagNode tree);
+  std::string getLocation(TrailingDirectives &tail);
 
   // Returns the location value to use.
   std::string handleLocationDirective(DagNode tree);
 
+  // Returns the C++ expression for the i-th type in a `returnType` directive.
+  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
+
   // Emits the C++ statement to build a new op out of the given DAG `tree` and
   // returns the variable name that this op is assigned to. If the root op in
   // DAG `tree` has a specified name, the created op will be assigned to a
@@ -271,9 +290,10 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
     capture.push_back(std::move(argName));
   }
 
-  bool hasLocationDirective;
-  std::string locToUse;
-  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+  auto tail = getTrailingDirectives(tree);
+  if (tail.returnType)
+    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
+  auto locToUse = getLocation(tail);
 
   auto fmt = tree.getNativeCodeTemplate();
   if (fmt.count("$_self") != 1)
@@ -286,14 +306,14 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
                  formatv("\"{0} return failure\"", nativeCodeCall));
 
-  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
     auto name = tree.getArgName(i);
     if (!name.empty() && name != "_") {
       os << formatv("{0} = {1};\n", name, capture[i]);
     }
   }
 
-  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
     std::string argName = capture[i];
 
     // Handle nested DAG construct first
@@ -884,6 +904,24 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
   return os.str();
 }
 
+std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
+                                                int depth) {
+  // Nested NativeCodeCall.
+  if (auto dagNode = returnType.getArgAsNestedDag(i)) {
+    if (!dagNode.isNativeCodeCall())
+      PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
+                           "call");
+    return handleReplaceWithNativeCodeCall(dagNode, depth);
+  }
+  // String literal.
+  auto dagLeaf = returnType.getArgAsLeaf(i);
+  if (dagLeaf.isStringAttr())
+    return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
+  return tgfmt(
+      "$0.getType()", &fmtCtx,
+      handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
+}
+
 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                              StringRef patArgName) {
   if (leaf.isStringAttr())
@@ -929,11 +967,12 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
 
   SmallVector<std::string, 16> attrs;
 
-  bool hasLocationDirective;
-  std::string locToUse;
-  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+  auto tail = getTrailingDirectives(tree);
+  if (tail.returnType)
+    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
+  auto locToUse = getLocation(tail);
 
-  for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
+  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
     if (tree.isNestedDagArg(i)) {
       attrs.push_back(
           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
@@ -1002,18 +1041,49 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
   return 1;
 }
 
-std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
-  auto numPatArgs = tree.getNumArgs();
+PatternEmitter::TrailingDirectives
+PatternEmitter::getTrailingDirectives(DagNode tree) {
+  TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
 
-  if (numPatArgs != 0) {
-    if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
-      if (lastArg.isLocationDirective()) {
-        return std::make_pair(true, handleLocationDirective(lastArg));
-      }
+  // Look backwards through the arguments.
+  auto numPatArgs = tree.getNumArgs();
+  for (int i = numPatArgs - 1; i >= 0; --i) {
+    auto dagArg = tree.getArgAsNestedDag(i);
+    // A leaf is not a directive. Stop looking.
+    if (!dagArg)
+      break;
+
+    auto isLocation = dagArg.isLocationDirective();
+    auto isReturnType = dagArg.isReturnTypeDirective();
+    // If encountered a DAG node that isn't a trailing directive, stop looking.
+    if (!(isLocation || isReturnType))
+      break;
+    // Save the directive, but error if one of the same type was already
+    // found.
+    ++tail.numDirectives;
+    if (isLocation) {
+      if (tail.location)
+        PrintFatalError(loc, "`location` directive can only be specified "
+                             "once");
+      tail.location = dagArg;
+    } else if (isReturnType) {
+      if (tail.returnType)
+        PrintFatalError(loc, "`returnType` directive can only be specified "
+                             "once");
+      tail.returnType = dagArg;
+    }
   }
 
+  return tail;
+}
+
+std::string
+PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
+  if (tail.location)
+    return handleLocationDirective(tail.location);
+
   // If no explicit location is given, use the default, all fused, location.
-  return std::make_pair(false, "odsLoc");
+  return "odsLoc";
 }
 
 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
@@ -1026,11 +1096,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
-  bool hasLocationDirective;
-  std::string locToUse;
-  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+  auto tail = getTrailingDirectives(tree);
+  auto locToUse = getLocation(tail);
 
-  auto inPattern = numPatArgs - hasLocationDirective;
+  auto inPattern = numPatArgs - tail.numDirectives;
   if (numOpArgs != inPattern) {
     PrintFatalError(loc,
                     formatv("resultant op '{0}' argument number mismatch: "
@@ -1045,7 +1114,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   // First go through all the child nodes who are nested DAG constructs to
   // create ops for them and remember the symbol names for them, so that we can
   // use the results in the current node. This happens in a recursive manner.
-  for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
+  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
     if (auto child = tree.getArgAsNestedDag(i))
       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
   }
@@ -1080,7 +1149,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   bool useFirstAttr =
       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
 
-  if (isSameOperandsAndResultType || useFirstAttr) {
+  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
     // We know how to deduce the result type for ops with these traits and we've
     // generated builders taking aggregate parameters. Use those builders to
     // create the ops.
@@ -1097,7 +1166,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
   bool usePartialResults = valuePackName != resultValue;
 
-  if (usePartialResults || depth > 0 || resultIndex < 0) {
+  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
     // For these cases (broadcastable ops, op results used both as auxiliary
     // values and replacement values, ops in nested patterns, auxiliary ops), we
     // still need to supply the result types when building the op. But because
@@ -1115,10 +1184,14 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     return resultValue;
   }
 
-  // If depth == 0 and resultIndex >= 0, it means we are replacing the values
-  // generated from the source pattern root op. Then we can use the source
-  // pattern's value types to determine the value type of the generated op
-  // here.
+  // If we are provided explicit return types, use them to build the op.
+  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
+  // the values generated from the source pattern root op. Then we must use the
+  // source pattern's value types to determine the value type of the generated
+  // op here.
+  if (depth == 0 && resultIndex >= 0 && tail.returnType)
+    PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
+                         "return values replace the source pattern's root op");
 
   // First prepare local variables for op arguments used in builder call.
   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
@@ -1128,11 +1201,20 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
                          "(void)tblgen_types;\n");
   int numResults = resultOp.getNumResults();
-  if (numResults != 0) {
-    for (int i = 0; i < numResults; ++i)
-      os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
-                    "  tblgen_types.push_back(v.getType());\n}\n",
-                    resultIndex + i);
+  if (tail.returnType) {
+    auto numRetTys = tail.returnType.getNumArgs();
+    for (int i = 0; i < numRetTys; ++i) {
+      auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
+      os << "tblgen_types.push_back(" << varName << ");\n";
+    }
+  } else {
+    if (numResults != 0) {
+      // Copy the result types from the source pattern.
+      for (int i = 0; i < numResults; ++i)
+        os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
+                      "  tblgen_types.push_back(v.getType());\n}\n",
+                      resultIndex + i);
+    }
   }
   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                 "tblgen_values, tblgen_attrs);\n",


        


More information about the Mlir-commits mailing list