[Mlir-commits] [mlir] e40624a - [mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 10 11:36:00 PST 2021
Author: Mogball
Date: 2021-12-10T19:35:56Z
New Revision: e40624ae604fc292cb8a7102b0b91b571b26a32a
URL: https://github.com/llvm/llvm-project/commit/e40624ae604fc292cb8a7102b0b91b571b26a32a
DIFF: https://github.com/llvm/llvm-project/commit/e40624ae604fc292cb8a7102b0b91b571b26a32a.diff
LOG: [mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D115522
Added:
Modified:
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 120749e78c83d..655ad2d04c508 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1136,7 +1136,9 @@ def ThreeResultOp : TEST_Op<"three_result"> {
let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
-def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AnotherThreeResultOp
+ : TEST_Op<"another_three_result",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
@@ -2101,6 +2103,53 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
}];
}
+// Base class for testing mixing allOperandTypes, allOperands, and
+// inferResultTypes.
+class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
+ : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let results = (outs Variadic<AnyType>:$outs);
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ ::mlir::TypeRange operandTypes = operands.getTypes();
+ inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
+ return ::mlir::success();
+ }
+ }];
+}
+
+// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
+def FormatInferTypeAllOperandsAndTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there is one
+// ODS operand.
+def FormatInferTypeAllOperandsAndTypesOneOperandOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
+ let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there are
+// more than one ODS operands.
+def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
+ let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperands is true and operand types
+// are separately specified.
+def FormatInferTypeAllTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
+}
+
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index c3214c7afab4d..c65d216c18b52 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -411,6 +411,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_infer_type
%ignored_res7 = test.format_infer_type
+// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32
+%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32
+
+// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
+
//===----------------------------------------------------------------------===//
// Check DefaultValuedStrAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 4223c2d2ee585..6203edb213e60 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -424,14 +424,18 @@ struct OperationFormat {
/// Generate the parser code for a specific format element.
void genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx);
- /// Generate the c++ to resolve the types of operands and results during
+ /// Generate the C++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to resolve regions during parsing.
+ /// Generate the C++ to resolve the types of the operands during parsing.
+ void genParserOperandTypeResolution(
+ Operator &op, MethodBody &body,
+ function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
+ /// Generate the C++ to resolve regions during parsing.
void genParserRegionResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to resolve successors during parsing.
+ /// Generate the C++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to handling variadic segment size traits.
+ /// Generate the C++ to handling variadic segment size traits.
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
/// Generate the operation printer from this format.
@@ -1462,17 +1466,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
}
}
+ // Emit the operand type resolutions.
+ genParserOperandTypeResolution(op, body, emitTypeResolver);
+
+ // Handle return type inference once all operands have been resolved
+ if (infersResultTypes)
+ body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+}
+
+void OperationFormat::genParserOperandTypeResolution(
+ Operator &op, MethodBody &body,
+ function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
// Early exit if there are no operands.
- if (op.getNumOperands() == 0) {
- // Handle return type inference here if there are no operands
- if (infersResultTypes)
- body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+ if (op.getNumOperands() == 0)
return;
- }
- // Handle the case where all operand types are in one group.
+ // Handle the case where all operand types are grouped together with
+ // "types(operands)".
if (allOperandTypes) {
- // If we have all operands together, use the full operand list directly.
+ // If `operands` was specified, use the full operand list directly.
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
"allOperandLoc, result.operands))\n"
@@ -1496,7 +1508,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
<< " return ::mlir::failure();\n";
return;
}
- // Handle the case where all of the operands were grouped together.
+
+ // Handle the case where all operands are grouped together with "operands".
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, ";
@@ -1551,10 +1564,6 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return ::mlir::failure();\n";
}
-
- // Handle return type inference once all operands have been resolved
- if (infersResultTypes)
- body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserRegionResolution(Operator &op,
@@ -1833,7 +1842,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
// keyword.
llvm::BitVector nonKeywordCases(cases.size());
bool hasStrCase = false;
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
hasStrCase = it.value().isStrCase();
if (!canFormatStringAsKeyword(it.value().getStr()))
nonKeywordCases.set(it.index());
@@ -1860,7 +1869,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
// overlap with other cases. For simplicity sake, only allow cases with a
// single bit value.
if (enumAttr.isBitEnum()) {
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
int64_t value = it.value().getValue();
if (value < 0 || !llvm::isPowerOf2_64(value))
nonKeywordCases.set(it.index());
@@ -1873,7 +1882,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
body << " switch (caseValue) {\n";
StringRef cppNamespace = enumAttr.getCppNamespace();
StringRef enumName = enumAttr.getEnumClassName();
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();
More information about the Mlir-commits
mailing list