[Mlir-commits] [mlir] 1760d8b - [mlir][ODS] Support result type inference in custom assembly format
Daniel Resnick
llvmlistbot at llvm.org
Mon Oct 11 13:09:23 PDT 2021
Author: Daniel Resnick
Date: 2021-10-11T14:07:56-06:00
New Revision: 1760d8b36b4804758a9a4801edc1d97c0ba4f25c
URL: https://github.com/llvm/llvm-project/commit/1760d8b36b4804758a9a4801edc1d97c0ba4f25c
DIFF: https://github.com/llvm/llvm-project/commit/1760d8b36b4804758a9a4801edc1d97c0ba4f25c.diff
LOG: [mlir][ODS] Support result type inference in custom assembly format
Operations that have the InferTypeOpInterface trait can now omit the return
types in their custom assembly formats.
Differential Revision: https://reviews.llvm.org/D111326
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 398121c8dfa0c..ec9e6fdc80ddb 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -929,6 +929,11 @@ these equal constraints to discern the types of missing variables. The currently
supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, and
`SameOperandsAndResultType`.
+* InferTypeOpInterface
+
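+Operations that implement `InferTypeOpInterface` can omit their result types in
+their assembly format, since the result types can be inferred by the op's
+`inferReturnTypes` method from the parsed operands (and attributes/regions).
+Inference is all-or-nothing: it is used only when the format specifies none of
+the op's result types.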
+
### `hasCanonicalizer`
This boolean field indicates whether canonicalization patterns have been defined
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 858ce7df5de0a..a94ec6dced389 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2021,6 +2021,24 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
let assemblyFormat = "attr-dict $value `:` type($value)";
}
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference in assembly format
+
+// Test op with a single `AnyType` result whose type is always inferred as i16
+// by the `inferReturnTypes` hook declared below. Because the result type comes
+// from InferTypeOpInterface, the assembly format ("attr-dict") does not mention
+// it at all.
+def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
+ let results = (outs AnyType);
+ let assemblyFormat = "attr-dict";
+
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+ return ::mlir::success();
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 8c6bb09f34a37..e4d0f9e3b21f8 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -3,6 +3,7 @@
// This file contains tests for the specification of the declarative op format.
include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
def TestDialect : Dialect {
let name = "test";
@@ -566,4 +567,6 @@ def ZCoverageValidH : TestFormat_Op<[{
operands type($result) attr-dict
}], [AllTypesMatch<["operand", "result"]>]>,
Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
-
+// Valid: the format never references type($result), and the trait list
+// includes InferTypeOpInterface, so the result type need not be spelled out.
+// NOTE(review): I64 is also a buildable type; whether the generated parser
+// takes the inferReturnTypes path or buildable-type resolution depends on
+// op.allResultTypesKnown() -- confirm this def exercises the inference path.
+def ZCoverageValidI : TestFormat_Op<[{
+ operands type(operands) attr-dict
+}], [InferTypeOpInterface]>, Arguments<(ins Variadic<I64>:$inputs)>, Results<(outs I64:$result)>;
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index ccaedba466597..4be71a9eee1f7 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -354,3 +354,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_types_match_context %[[I64]] : i64
%ignored_res6 = test.format_types_match_context %i64 : i64
+
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_infer_type
+%ignored_res7 = test.format_infer_type
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index bffd161fb7ee3..708b9b145851c 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -441,7 +441,8 @@ struct OperationFormat {
};
OperationFormat(const Operator &op)
- : allOperands(false), allOperandTypes(false), allResultTypes(false) {
+ : allOperands(false), allOperandTypes(false), allResultTypes(false),
+ infersResultTypes(false) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
@@ -482,6 +483,9 @@ struct OperationFormat {
/// contains these, it can not contain individual type resolvers.
bool allOperands, allOperandTypes, allResultTypes;
+ /// A flag indicating if this operation infers its result types
+ bool infersResultTypes;
+
/// A flag indicating if this operation has the SingleBlockImplicitTerminator
/// trait.
bool hasImplicitTermTrait;
@@ -682,6 +686,19 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults();
)";
+/// The code snippet used to generate a parser call to infer return types.
+///
+/// The emitted code calls the op's static `inferReturnTypes` hook with the
+/// parser context, `result.location`, `result.operands`, the attributes parsed
+/// so far (as a DictionaryAttr), and `result.regions`. If inference fails, the
+/// parse fails; otherwise the inferred types are appended to `result`.
+/// It is emitted by genParserTypeResolution, either in place of explicit result
+/// type resolution when the op has no operands, or after all operands have been
+/// resolved.
+/// NOTE(review): confirm that regions (and any attributes added later, e.g.
+/// segment sizes) are already in `result` when this snippet runs; otherwise
+/// inference that depends on them sees incomplete state.
+///
+/// {0}: The operation class name.
+const char *const inferReturnTypesParserCode = R"(
+ ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
+ if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
+ result.location, result.operands,
+ result.attributes.getDictionary(parser.getContext()),
+ result.regions, inferredReturnTypes)))
+ return ::mlir::failure();
+ result.addTypes(inferredReturnTypes);
+)";
+
/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
@@ -1437,19 +1454,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
};
// Resolve each of the result types.
- if (allResultTypes) {
- body << " result.addTypes(allResultTypes);\n";
- } else {
- for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
- body << " result.addTypes(";
- emitTypeResolver(resultTypes[i], op.getResultName(i));
- body << ");\n";
+ if (!infersResultTypes) {
+ if (allResultTypes) {
+ body << " result.addTypes(allResultTypes);\n";
+ } else {
+ for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+ body << " result.addTypes(";
+ emitTypeResolver(resultTypes[i], op.getResultName(i));
+ body << ");\n";
+ }
}
}
// Early exit if there are no operands.
- if (op.getNumOperands() == 0)
+ if (op.getNumOperands() == 0) {
+ // Handle return type inference here if there are no operands
+ if (infersResultTypes)
+ body << formatv(inferReturnTypesParserCode, op.getCppClassName());
return;
+ }
// Handle the case where all operand types are in one group.
if (allOperandTypes) {
@@ -1532,6 +1555,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return ::mlir::failure();\n";
}
+
+ // Handle return type inference once all operands have been resolved
+ if (infersResultTypes)
+ body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserRegionResolution(Operator &op,
@@ -2478,6 +2505,7 @@ class FormatParser {
// during parsing.
bool hasAttrDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
+ bool canInferResultTypes = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
@@ -2515,6 +2543,9 @@ LogicalResult FormatParser::parse() {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
+ } else if (def.getName() == "InferTypeOpInterface" &&
+ !op.allResultTypesKnown()) {
+ canInferResultTypes = true;
}
}
@@ -2684,6 +2715,14 @@ LogicalResult FormatParser::verifyResults(
if (fmt.allResultTypes)
return ::mlir::success();
+ // If no result types are specified and we can infer them, infer all result
+ // types
+ if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
+ canInferResultTypes) {
+ fmt.infersResultTypes = true;
+ return ::mlir::success();
+ }
+
// Check that all of the result types can be inferred.
auto &buildableTypes = fmt.buildableTypes;
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
More information about the Mlir-commits
mailing list