[Mlir-commits] [mlir] e40624a - [mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 10 11:36:00 PST 2021


Author: Mogball
Date: 2021-12-10T19:35:56Z
New Revision: e40624ae604fc292cb8a7102b0b91b571b26a32a

URL: https://github.com/llvm/llvm-project/commit/e40624ae604fc292cb8a7102b0b91b571b26a32a
DIFF: https://github.com/llvm/llvm-project/commit/e40624ae604fc292cb8a7102b0b91b571b26a32a.diff

LOG: [mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D115522

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 120749e78c83d..655ad2d04c508 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1136,7 +1136,9 @@ def ThreeResultOp : TEST_Op<"three_result"> {
   let results = (outs I32:$result1, F32:$result2, F32:$result3);
 }
 
-def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AnotherThreeResultOp
+    : TEST_Op<"another_three_result",
+              [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins MultiResultOpEnum:$kind);
   let results = (outs I32:$result1, F32:$result2, F32:$result3);
 }
@@ -2101,6 +2103,53 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
    }];
 }
 
+// Base class for testing combinations of allOperandTypes, allOperands, and
+// inferReturnTypes.
+class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
+    : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$outs);
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      ::mlir::TypeRange operandTypes = operands.getTypes();
+      inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
+      return ::mlir::success();
+    }
+   }];
+}
+
+// Test inferReturnTypes is called when allOperandTypes and allOperands are
+// true.
+def FormatInferTypeAllOperandsAndTypesOp
+    : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
+  let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there is one
+// ODS operand.
+def FormatInferTypeAllOperandsAndTypesOneOperandOp
+    : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
+  let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there is
+// more than one ODS operand.
+def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
+    : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
+                                [SameVariadicOperandSize]> {
+  let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
+  let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperands is true and operand types
+// are separately specified.
+def FormatInferTypeAllTypesOp
+    : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
+  let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index c3214c7afab4d..c65d216c18b52 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -411,6 +411,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 // CHECK: test.format_infer_type
 %ignored_res7 = test.format_infer_type
 
+// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32
+%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32
+
+// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
+
 //===----------------------------------------------------------------------===//
 // Check DefaultValuedStrAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 4223c2d2ee585..6203edb213e60 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -424,14 +424,18 @@ struct OperationFormat {
   /// Generate the parser code for a specific format element.
   void genElementParser(Element *element, MethodBody &body,
                         FmtContext &attrTypeCtx);
-  /// Generate the c++ to resolve the types of operands and results during
+  /// Generate the C++ to resolve the types of operands and results during
   /// parsing.
   void genParserTypeResolution(Operator &op, MethodBody &body);
-  /// Generate the c++ to resolve regions during parsing.
+  /// Generate the C++ to resolve the types of the operands during parsing.
+  void genParserOperandTypeResolution(
+      Operator &op, MethodBody &body,
+      function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
+  /// Generate the C++ to resolve regions during parsing.
   void genParserRegionResolution(Operator &op, MethodBody &body);
-  /// Generate the c++ to resolve successors during parsing.
+  /// Generate the C++ to resolve successors during parsing.
   void genParserSuccessorResolution(Operator &op, MethodBody &body);
-  /// Generate the c++ to handling variadic segment size traits.
+  /// Generate the C++ to handle variadic segment size traits.
   void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
 
   /// Generate the operation printer from this format.
@@ -1462,17 +1466,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
     }
   }
 
+  // Emit the operand type resolutions.
+  genParserOperandTypeResolution(op, body, emitTypeResolver);
+
+  // Handle return type inference once all operands have been resolved.
+  if (infersResultTypes)
+    body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+}
+
+void OperationFormat::genParserOperandTypeResolution(
+    Operator &op, MethodBody &body,
+    function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
   // Early exit if there are no operands.
-  if (op.getNumOperands() == 0) {
-    // Handle return type inference here if there are no operands
-    if (infersResultTypes)
-      body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+  if (op.getNumOperands() == 0)
     return;
-  }
 
-  // Handle the case where all operand types are in one group.
+  // Handle the case where all operand types are grouped together with
+  // "type(operands)".
   if (allOperandTypes) {
-    // If we have all operands together, use the full operand list directly.
+    // If `operands` was specified, use the full operand list directly.
     if (allOperands) {
       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
               "allOperandLoc, result.operands))\n"
@@ -1496,7 +1508,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
          << "    return ::mlir::failure();\n";
     return;
   }
-  // Handle the case where all of the operands were grouped together.
+
+  // Handle the case where all operands are grouped together with "operands".
   if (allOperands) {
     body << "  if (parser.resolveOperands(allOperands, ";
 
@@ -1551,10 +1564,6 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
       body << ", " << operand.name << "OperandsLoc";
     body << ", result.operands))\n    return ::mlir::failure();\n";
   }
-
-  // Handle return type inference once all operands have been resolved
-  if (infersResultTypes)
-    body << formatv(inferReturnTypesParserCode, op.getCppClassName());
 }
 
 void OperationFormat::genParserRegionResolution(Operator &op,
@@ -1833,7 +1842,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // keyword.
   llvm::BitVector nonKeywordCases(cases.size());
   bool hasStrCase = false;
-  for (auto it : llvm::enumerate(cases)) {
+  for (auto &it : llvm::enumerate(cases)) {
     hasStrCase = it.value().isStrCase();
     if (!canFormatStringAsKeyword(it.value().getStr()))
       nonKeywordCases.set(it.index());
@@ -1860,7 +1869,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // overlap with other cases. For simplicity sake, only allow cases with a
   // single bit value.
   if (enumAttr.isBitEnum()) {
-    for (auto it : llvm::enumerate(cases)) {
+    for (auto &it : llvm::enumerate(cases)) {
       int64_t value = it.value().getValue();
       if (value < 0 || !llvm::isPowerOf2_64(value))
         nonKeywordCases.set(it.index());
@@ -1873,7 +1882,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
     body << "    switch (caseValue) {\n";
     StringRef cppNamespace = enumAttr.getCppNamespace();
     StringRef enumName = enumAttr.getEnumClassName();
-    for (auto it : llvm::enumerate(cases)) {
+    for (auto &it : llvm::enumerate(cases)) {
       if (nonKeywordCases.test(it.index()))
         continue;
       StringRef symbol = it.value().getSymbol();


        


More information about the Mlir-commits mailing list