[Mlir-commits] [mlir] cb8c30d - [DRR] Explicit Return Types in Rewrites
Jacques Pienaar
llvmlistbot at llvm.org
Wed Sep 15 14:25:51 PDT 2021
Author: Mogball
Date: 2021-09-15T14:25:29-07:00
New Revision: cb8c30d35dc9eedca4b8073e96f06e9ce8f12192
URL: https://github.com/llvm/llvm-project/commit/cb8c30d35dc9eedca4b8073e96f06e9ce8f12192
DIFF: https://github.com/llvm/llvm-project/commit/cb8c30d35dc9eedca4b8073e96f06e9ce8f12192.diff
LOG: [DRR] Explicit Return Types in Rewrites
Adds a new rewrite directive returnType that can be added at the end of an op's
argument list to explicitly specify return types.
```
(OpX $v0, $v1, (returnType "$_builder.getI32Type()"))
```
Pass in a bound value to copy its type, or pass a native code call to
dynamically create new types.
```
(OpX $v0, $v1, (returnType $v0, (NativeCodeCall<"..."> $v1)))
```
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D109472
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/test/mlir-tblgen/rewriter-errors.td
mlir/test/mlir-tblgen/rewriter-indexing.td
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index cb9742428101..ac99ac4d0897 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2637,15 +2637,49 @@ def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult
// Rewrite directives
//===----------------------------------------------------------------------===//
-// Directive used in result pattern to specify the location of the generated
-// op. This directive must be used as the last argument to the op creation
-// DAG construct. The arguments to location must be previously captured symbol.
-def location;
-
// Directive used in result pattern to indicate that no new op are generated,
// so to replace the matched DAG with an existing SSA value.
def replaceWithValue;
+// Directive used in result patterns to specify the location of the generated
+// op. This directive must be used as a trailing argument to op creation or
+// native code calls.
+//
+// Usage:
+// * Create a named location: `(location "myLocation")`
+// * Copy the location of a captured symbol: `(location $arg)`
+// * Create a fused location: `(location "metadata", $arg0, $arg1)`
+
+def location;
+
+// Directive used in result patterns to specify return types for a created op.
+// This allows ops to be created without relying on type inference with
+// `OpTraits` or an op builder with deduction.
+//
+// This directive must be used as a trailing argument to op creation.
+//
+// Specify one return type with a string literal:
+//
+// ```
+// (AnOp $val, (returnType "$_builder.getI32Type()"))
+// ```
+//
+// Pass a captured value to copy its type:
+//
+// ```
+// (AnOp $val, (returnType $val))
+// ```
+//
+// Pass a native code call inside a DAG to create a new type with arguments.
+//
+// ```
+// (AnOp $val,
+//       (returnType (NativeCodeCall<"$_builder.getTupleType({$0})"> $val)))
+// ```
+//
+// Specify multiple return types by passing several of any of the above forms.
+// Not allowed on an op whose results replace the source pattern's root op.
+def returnType;
//===----------------------------------------------------------------------===//
// Attribute and Type generation
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 047abbef27ac..a3786cd8e0b8 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -176,6 +176,9 @@ class DagNode {
// Returns whether this DAG represents the location of an op creation.
bool isLocationDirective() const;
+ // Returns whether this DAG is a return type specifier.
+ bool isReturnTypeDirective() const;
+
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index e7d5a774ad84..ce225ed93076 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -54,9 +54,7 @@ bool DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCaseInfo");
}
-bool DagLeaf::isStringAttr() const {
- return isa<llvm::StringInit>(def);
-}
+bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
Constraint DagLeaf::getAsConstraint() const {
assert((isOperandMatcher() || isAttrMatcher()) &&
@@ -114,7 +112,8 @@ bool DagNode::isNativeCodeCall() const {
}
bool DagNode::isOperation() const {
- return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
+ return !isNativeCodeCall() && !isReplaceWithValue() &&
+ !isLocationDirective() && !isReturnTypeDirective();
}
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -180,6 +179,11 @@ bool DagNode::isLocationDirective() const {
return dagOpDef->getName() == "location";
}
+bool DagNode::isReturnTypeDirective() const {
+ auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+ return dagOpDef->getName() == "returnType";
+}
+
void DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
@@ -753,14 +757,18 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
- // The pattern might have the last argument specifying the location.
- bool hasLocDirective = false;
- if (numTreeArgs != 0) {
- if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
- hasLocDirective = lastArg.isLocationDirective();
+  // The pattern might have trailing directives. Count only the contiguous run
+  // at the end; stop at the first leaf or non-directive DAG argument.
+  int numDirectives = 0;
+  for (int i = numTreeArgs - 1; i >= 0; --i) {
+    auto dagArg = tree.getArgAsNestedDag(i);
+    if (!dagArg ||
+        !(dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()))
+      break;
+    ++numDirectives;
}
- if (numOpArgs != numTreeArgs - hasLocDirective) {
+ if (numOpArgs != numTreeArgs - numDirectives) {
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fcc0f2861ccf..e17a76b67a51 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1341,6 +1341,64 @@ def : Pat<(TestLocationSrcOp:$res1
(location "named")),
(location "fused", $res2, $res3))>;
+//===----------------------------------------------------------------------===//
+// Test Patterns (Type Builders)
+
+def SourceOp : TEST_Op<"source_op"> {
+ let arguments = (ins AnyInteger:$arg, AnyI32Attr:$tag);
+ let results = (outs AnyInteger);
+}
+
+// An op without return type deduction.
+def OpX : TEST_Op<"op_x"> {
+ let arguments = (ins AnyInteger:$input);
+ let results = (outs AnyInteger);
+}
+
+// Test that ops without built-in type deduction can be created in the
+// replacement DAG with an explicitly specified type.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "11">:$attr),
+ (OpX (OpX $val, (returnType "$_builder.getI32Type()")))>;
+// Test NativeCodeCall type builder can accept arguments.
+def SameTypeAs : NativeCodeCall<"$0.getType()">;
+
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "22">:$attr),
+ (OpX (OpX $val, (returnType (SameTypeAs $val))))>;
+
+// Test multiple return types.
+def MakeI64Type : NativeCodeCall<"$_builder.getI64Type()">;
+def MakeI32Type : NativeCodeCall<"$_builder.getI32Type()">;
+
+def OneToTwo : TEST_Op<"one_to_two"> {
+ let arguments = (ins AnyInteger);
+ let results = (outs AnyInteger, AnyInteger);
+}
+
+def TwoToOne : TEST_Op<"two_to_one"> {
+ let arguments = (ins AnyInteger, AnyInteger);
+ let results = (outs AnyInteger);
+}
+
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "33">:$attr),
+ (TwoToOne (OpX (OneToTwo:$res__0 $val, (returnType (MakeI64Type), (MakeI32Type))), (returnType (MakeI32Type))),
+ (OpX $res__1, (returnType (MakeI64Type))))>;
+
+// Test copy value return type.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "44">:$attr),
+ (OpX (OpX $val, (returnType $val)))>;
+
+// Test creating multiple return types with different methods.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "55">:$attr),
+ (TwoToOne (OneToTwo:$res__0 $val, (returnType $val, "$_builder.getI64Type()")), $res__1)>;
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Trailing Directives)
+
+// Test that we can specify both `location` and `returnType` directives.
+def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "66">:$attr),
+ (TwoToOne (OpX $val, (returnType $val), (location "loc1")),
+ (OpX $val, (location "loc2"), (returnType $val)))>;
+
//===----------------------------------------------------------------------===//
// Test Legalization
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 8af0ef4a6522..4a05df4782c7 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -530,3 +530,56 @@ func @redundantTest(%arg0: i32) -> i32 {
// CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
return %0 : i32
}
+
+//===----------------------------------------------------------------------===//
+// Test that ops without type deduction can be created with type builders.
+//===----------------------------------------------------------------------===//
+
+func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
+ %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
+ // CHECK: "test.op_x"(%arg0) : (i64) -> i32
+ // CHECK: "test.op_x"(%0) : (i32) -> i8
+ return %0 : i8
+}
+
+func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
+ %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
+ // CHECK: "test.op_x"(%arg0) : (i1) -> i1
+ // CHECK: "test.op_x"(%0) : (i1) -> i8
+ return %0 : i8
+}
+
+func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
+ %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
+ // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
+ // CHECK: "test.op_x"(%0#0) : (i64) -> i32
+ // CHECK: "test.op_x"(%0#1) : (i32) -> i64
+ // CHECK: "test.two_to_one"(%1, %2) : (i32, i64) -> i1
+ return %0 : i1
+}
+
+func @copyValueType(%arg0 : i8) -> i32 {
+ %0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
+ // CHECK: "test.op_x"(%arg0) : (i8) -> i8
+ // CHECK: "test.op_x"(%0) : (i8) -> i32
+ return %0 : i32
+}
+
+func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
+ %0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
+ // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
+ // CHECK: "test.two_to_one"(%0#0, %0#1) : (i1, i64) -> i64
+ return %0 : i64
+}
+
+//===----------------------------------------------------------------------===//
+// Test that multiple trailing directives can be mixed in patterns.
+//===----------------------------------------------------------------------===//
+
+func @returnTypeAndLocation(%arg0 : i32) -> i1 {
+ %0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
+ // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
+ // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc2")
+ // CHECK: "test.two_to_one"(%0, %1) : (i32, i32) -> i1
+ return %0 : i1
+}
diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td
index 60e4710688e7..3cebe2dfd34f 100644
--- a/mlir/test/mlir-tblgen/rewriter-errors.td
+++ b/mlir/test/mlir-tblgen/rewriter-errors.td
@@ -1,6 +1,8 @@
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR5 %s 2>&1 | FileCheck --check-prefix=ERROR5 %s
include "mlir/IR/OpBase.td"
@@ -35,3 +37,15 @@ def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">;
def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
(OpB $val, $arg)>;
#endif
+
+#ifdef ERROR4
+// Check that passing an op as a DAG node inside `returnType` fails.
+// ERROR4: [[@LINE+1]]:1: error: nested DAG in `returnType` must be a native code
+def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA (OpA $val, $val, (returnType (OpA $val, $val))), $val)>;
+#endif
+
+#ifdef ERROR5
+// Check that trying to specify explicit types at the root node fails.
+// ERROR5: [[@LINE+1]]:1: error: Cannot specify explicit return types in an op
+def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA $val, $val, (returnType "someType()"))>;
+#endif
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index f4f055e1c0c4..e31d78c2481e 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -90,3 +90,13 @@ def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
(NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// Check Pattern with return type builder.
+def SameTypeAs : NativeCodeCall<"$0.getType()">;
+// CHECK: struct test6 : public ::mlir::RewritePattern {
+// CHECK: tblgen_types.push_back((*v2.begin()).getType())
+// CHECK: tblgen_types.push_back(rewriter.getI32Type())
+// CHECK: nativeVar_1 = doSomething((*v3.begin()))
+// CHECK: tblgen_types.push_back(nativeVar_1)
+def test6 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+ (AOp (AOp $v1, (returnType $v2, "$_builder.getI32Type()", (NativeCodeCall<"doSomething($0)"> $v3))))>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 0ef5a5789427..37fe800c2d4c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -127,12 +127,31 @@ class PatternEmitter {
// Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree);
+ // Trailing directives are used at the end of DAG node argument lists to
+ // specify additional behaviour for op matchers and creators, etc.
+ struct TrailingDirectives {
+ // DAG node containing the `location` directive. Null if there is none.
+ DagNode location;
+
+ // DAG node containing the `returnType` directive. Null if there is none.
+ DagNode returnType;
+
+ // Number of found trailing directives.
+ int numDirectives;
+ };
+
+ // Collect any trailing directives.
+ TrailingDirectives getTrailingDirectives(DagNode tree);
+
// Returns the location value to use.
- std::pair<bool, std::string> getLocation(DagNode tree);
+ std::string getLocation(TrailingDirectives &tail);
// Returns the location value to use.
std::string handleLocationDirective(DagNode tree);
+ // Emit return type argument.
+ std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
+
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in
// DAG `tree` has a specified name, the created op will be assigned to a
@@ -271,9 +290,10 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
capture.push_back(std::move(argName));
}
- bool hasLocationDirective;
- std::string locToUse;
- std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+ auto tail = getTrailingDirectives(tree);
+ if (tail.returnType)
+ PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
+ auto locToUse = getLocation(tail);
auto fmt = tree.getNativeCodeTemplate();
if (fmt.count("$_self") != 1)
@@ -286,14 +306,14 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
formatv("\"{0} return failure\"", nativeCodeCall));
- for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
auto name = tree.getArgName(i);
if (!name.empty() && name != "_") {
os << formatv("{0} = {1};\n", name, capture[i]);
}
}
- for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
std::string argName = capture[i];
// Handle nested DAG construct first
@@ -884,6 +904,24 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
return os.str();
}
+std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
+                                                int depth) {
+  // Nested DAG: must be a NativeCodeCall whose result is the type expression.
+  if (auto dagNode = returnType.getArgAsNestedDag(i)) {
+    if (!dagNode.isNativeCodeCall())
+      PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
+                           "call");
+    return handleReplaceWithNativeCodeCall(dagNode, depth);
+  }
+  // String literal: formatted with fmtCtx and used as the type expression.
+  auto dagLeaf = returnType.getArgAsLeaf(i);
+  if (dagLeaf.isStringAttr())
+    return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
+  // Otherwise a bound value: copy that value's type.
+  return tgfmt("$0.getType()", &fmtCtx,
+               handleOpArgument(dagLeaf, returnType.getArgName(i)));
+}
+
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
StringRef patArgName) {
if (leaf.isStringAttr())
@@ -929,11 +967,12 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
SmallVector<std::string, 16> attrs;
- bool hasLocationDirective;
- std::string locToUse;
- std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+ auto tail = getTrailingDirectives(tree);
+ if (tail.returnType)
+ PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
+ auto locToUse = getLocation(tail);
- for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
+ for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
if (tree.isNestedDagArg(i)) {
attrs.push_back(
handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
@@ -1002,18 +1041,49 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
return 1;
}
-std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
- auto numPatArgs = tree.getNumArgs();
+PatternEmitter::TrailingDirectives
+PatternEmitter::getTrailingDirectives(DagNode tree) {
+ TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
- if (numPatArgs != 0) {
- if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
- if (lastArg.isLocationDirective()) {
- return std::make_pair(true, handleLocationDirective(lastArg));
- }
+ // Look backwards through the arguments.
+ auto numPatArgs = tree.getNumArgs();
+ for (int i = numPatArgs - 1; i >= 0; --i) {
+ auto dagArg = tree.getArgAsNestedDag(i);
+ // A leaf is not a directive. Stop looking.
+ if (!dagArg)
+ break;
+
+ auto isLocation = dagArg.isLocationDirective();
+ auto isReturnType = dagArg.isReturnTypeDirective();
+ // If encountered a DAG node that isn't a trailing directive, stop looking.
+ if (!(isLocation || isReturnType))
+ break;
+ // Save the directive, but error if one of the same type was already
+ // found.
+ ++tail.numDirectives;
+ if (isLocation) {
+ if (tail.location)
+ PrintFatalError(loc, "`location` directive can only be specified "
+ "once");
+ tail.location = dagArg;
+ } else if (isReturnType) {
+ if (tail.returnType)
+ PrintFatalError(loc, "`returnType` directive can only be specified "
+ "once");
+ tail.returnType = dagArg;
+ }
}
+ return tail;
+}
+
+std::string
+PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
+ if (tail.location)
+ return handleLocationDirective(tail.location);
+
// If no explicit location is given, use the default, all fused, location.
- return std::make_pair(false, "odsLoc");
+ return "odsLoc";
}
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
@@ -1026,11 +1096,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
- bool hasLocationDirective;
- std::string locToUse;
- std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+ auto tail = getTrailingDirectives(tree);
+ auto locToUse = getLocation(tail);
- auto inPattern = numPatArgs - hasLocationDirective;
+ auto inPattern = numPatArgs - tail.numDirectives;
if (numOpArgs != inPattern) {
PrintFatalError(loc,
formatv("resultant op '{0}' argument number mismatch: "
@@ -1045,7 +1114,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// First go through all the child nodes who are nested DAG constructs to
// create ops for them and remember the symbol names for them, so that we can
// use the results in the current node. This happens in a recursive manner.
- for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
+ for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i))
childNodeNames[i] = handleResultPattern(child, i, depth + 1);
}
@@ -1080,7 +1149,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
bool useFirstAttr =
resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
- if (isSameOperandsAndResultType || useFirstAttr) {
+ if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
// We know how to deduce the result type for ops with these traits and we've
// generated builders taking aggregate parameters. Use those builders to
// create the ops.
@@ -1097,7 +1166,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
bool usePartialResults = valuePackName != resultValue;
- if (usePartialResults || depth > 0 || resultIndex < 0) {
+ if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
// For these cases (broadcastable ops, op results used both as auxiliary
// values and replacement values, ops in nested patterns, auxiliary ops), we
// still need to supply the result types when building the op. But because
@@ -1115,10 +1184,14 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
return resultValue;
}
- // If depth == 0 and resultIndex >= 0, it means we are replacing the values
- // generated from the source pattern root op. Then we can use the source
- // pattern's value types to determine the value type of the generated op
- // here.
+ // If we are provided explicit return types, use them to build the op.
+ // However, if depth == 0 and resultIndex >= 0, it means we are replacing
+ // the values generated from the source pattern root op. Then we must use the
+ // source pattern's value types to determine the value type of the generated
+ // op here.
+ if (depth == 0 && resultIndex >= 0 && tail.returnType)
+ PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
+ "return values replace the source pattern's root op");
// First prepare local variables for op arguments used in builder call.
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
@@ -1128,11 +1201,20 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
"(void)tblgen_types;\n");
int numResults = resultOp.getNumResults();
- if (numResults != 0) {
- for (int i = 0; i < numResults; ++i)
- os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
- " tblgen_types.push_back(v.getType());\n}\n",
- resultIndex + i);
+ if (tail.returnType) {
+ auto numRetTys = tail.returnType.getNumArgs();
+ for (int i = 0; i < numRetTys; ++i) {
+ auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
+ os << "tblgen_types.push_back(" << varName << ");\n";
+ }
+ } else {
+ if (numResults != 0) {
+ // Copy the result types from the source pattern.
+ for (int i = 0; i < numResults; ++i)
+ os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
+ " tblgen_types.push_back(v.getType());\n}\n",
+ resultIndex + i);
+ }
}
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
"tblgen_values, tblgen_attrs);\n",
More information about the Mlir-commits
mailing list