[Mlir-commits] [mlir] 1aa0b84 - [mlir][ods] Fix OpFormatGen calling inferReturnTypes before region/segment resolution

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 16 11:04:54 PST 2021


Author: Mogball
Date: 2021-12-16T19:04:50Z
New Revision: 1aa0b84fa468076262e18d5071d284cac5b402a3

URL: https://github.com/llvm/llvm-project/commit/1aa0b84fa468076262e18d5071d284cac5b402a3
DIFF: https://github.com/llvm/llvm-project/commit/1aa0b84fa468076262e18d5071d284cac5b402a3.diff

LOG: [mlir][ods] Fix OpFormatGen calling inferReturnTypes before region/segment resolution

The generated parser for ops with type inference calls `inferReturnTypes` before region resolution and segment attribute resolution. As a result, the regions and the segment attributes are not passed to `inferReturnTypes`, even though it may need that information.

In particular, an op with attribute-sized operand segments whose `inferReturnTypes` function queries those operands will crash, because the segment attributes have not been added yet. This patch moves type resolution after region, successor, and variadic segment resolution.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D115782

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 627fac4c11874..bcb6e1e590f60 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2150,6 +2150,48 @@ def FormatInferTypeAllTypesOp
   let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
 }
 
+// Test inferReturnTypes coupled with regions (results mirror region arg types).
+def FormatInferTypeRegionsOp
+    : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> {
+  let results = (outs Variadic<AnyType>:$outs);
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = "$region attr-dict";
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      if (regions.empty()) // Parser must resolve regions before inference.
+        return ::mlir::failure();
+      auto types = regions.front()->getArgumentTypes(); // First region's args.
+      inferredReturnTypes.assign(types.begin(), types.end());
+      return ::mlir::success();
+    }
+  }];
+}
+
+// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes).
+def FormatInferTypeVariadicOperandsOp
+    : TEST_Op<"format_infer_type_variadic_operands",
+              [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<I32>:$a, Variadic<I64>:$b);
+  let results = (outs Variadic<AnyType>:$outs);
+  let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict";
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      FormatInferTypeVariadicOperandsOpAdaptor adaptor(operands, attributes); // Needs segment sizes in `attributes`.
+      auto aTypes = adaptor.getA().getTypes();
+      auto bTypes = adaptor.getB().getTypes();
+      inferredReturnTypes.append(aTypes.begin(), aTypes.end()); // $a types first,
+      inferredReturnTypes.append(bTypes.begin(), bTypes.end()); // then $b types.
+      return ::mlir::success();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index c65d216c18b52..152cd0a554f1a 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -423,6 +423,16 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 // CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
 %ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
 
+// CHECK: test.format_infer_type_regions
+// CHECK-NEXT: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):
+%ignored_res12:2 = test.format_infer_type_regions {
+^bb0(%arg0: i32, %arg1: f32):
+  "test.terminator"() : () -> ()
+}
+
+// CHECK: test.format_infer_type_variadic_operands(%[[I32]], %[[I32]] : i32, i32) (%[[I64]], %[[I64]] : i64, i64)
+%ignored_res13:4 = test.format_infer_type_variadic_operands(%i32, %i32 : i32, i32) (%i64, %i64 : i64, i64)
+
 //===----------------------------------------------------------------------===//
 // Check DefaultValuedStrAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 6203edb213e60..2b5767e325679 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1185,10 +1185,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
 
   // Generate the code to resolve the operand/result types and successors now
   // that they have been parsed.
-  genParserTypeResolution(op, body);
   genParserRegionResolution(op, body);
   genParserSuccessorResolution(op, body);
   genParserVariadicSegmentResolution(op, body);
+  genParserTypeResolution(op, body);
 
   body << "  return ::mlir::success();\n";
 }


        


More information about the Mlir-commits mailing list