[Mlir-commits] [mlir] 1760d8b - [mlir][ODS] Support result type inference in custom assembly format

Daniel Resnick llvmlistbot at llvm.org
Mon Oct 11 13:09:23 PDT 2021


Author: Daniel Resnick
Date: 2021-10-11T14:07:56-06:00
New Revision: 1760d8b36b4804758a9a4801edc1d97c0ba4f25c

URL: https://github.com/llvm/llvm-project/commit/1760d8b36b4804758a9a4801edc1d97c0ba4f25c
DIFF: https://github.com/llvm/llvm-project/commit/1760d8b36b4804758a9a4801edc1d97c0ba4f25c.diff

LOG: [mlir][ODS] Support result type inference in custom assembly format

Operations that have the InferTypeOpInterface trait can now omit the return
types in their custom assembly formats.

Differential Revision: https://reviews.llvm.org/D111326

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 398121c8dfa0c..ec9e6fdc80ddb 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -929,6 +929,11 @@ these equal constraints to discern the types of missing variables. The currently
 supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, and
 `SameOperandsAndResultType`.
 
+*   InferTypeOpInterface
+
+Operations that implement `InferTypeOpInterface` can omit their result types in
+their assembly format; they are inferred from operands, attributes, and regions.
+
 ### `hasCanonicalizer`
 
 This boolean field indicate whether canonicalization patterns have been defined

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 858ce7df5de0a..a94ec6dced389 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2021,6 +2021,24 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
   let assemblyFormat = "attr-dict $value `:` type($value)";
 }
 
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference in assembly format
+
+def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
+  let results = (outs AnyType);
+  let assemblyFormat = "attr-dict";
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+      return ::mlir::success();
+    }
+   }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 8c6bb09f34a37..e4d0f9e3b21f8 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -3,6 +3,7 @@
 // This file contains tests for the specification of the declarative op format.
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def TestDialect : Dialect {
   let name = "test";
@@ -566,4 +567,6 @@ def ZCoverageValidH : TestFormat_Op<[{
   operands type($result) attr-dict
 }], [AllTypesMatch<["operand", "result"]>]>,
      Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
-
+def ZCoverageValidI : TestFormat_Op<[{
+  operands type(operands) attr-dict
+}], [InferTypeOpInterface]>, Arguments<(ins Variadic<I64>:$inputs)>, Results<(outs I64:$result)>;

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index ccaedba466597..4be71a9eee1f7 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -354,3 +354,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 
 // CHECK: test.format_types_match_context %[[I64]] : i64
 %ignored_res6 = test.format_types_match_context %i64 : i64
+
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_infer_type
+%ignored_res7 = test.format_infer_type

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index bffd161fb7ee3..708b9b145851c 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -441,7 +441,8 @@ struct OperationFormat {
   };
 
   OperationFormat(const Operator &op)
-      : allOperands(false), allOperandTypes(false), allResultTypes(false) {
+      : allOperands(false), allOperandTypes(false), allResultTypes(false),
+        infersResultTypes(false) {
     operandTypes.resize(op.getNumOperands(), TypeResolution());
     resultTypes.resize(op.getNumResults(), TypeResolution());
 
@@ -482,6 +483,9 @@ struct OperationFormat {
   /// contains these, it can not contain individual type resolvers.
   bool allOperands, allOperandTypes, allResultTypes;
 
+  /// A flag indicating if this operation infers its result types.
+  bool infersResultTypes;
+
   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
   /// trait.
   bool hasImplicitTermTrait;
@@ -682,6 +686,19 @@ const char *const functionalTypeParserCode = R"(
   {1}Types = {0}__{1}_functionType.getResults();
 )";
 
+/// The code snippet used to generate a parser call to infer return types.
+///
+/// {0}: The operation class name.
+const char *const inferReturnTypesParserCode = R"(
+  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
+  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
+      result.location, result.operands,
+      result.attributes.getDictionary(parser.getContext()),
+      result.regions, inferredReturnTypes)))
+    return ::mlir::failure();
+  result.addTypes(inferredReturnTypes);
+)";
+
 /// The code snippet used to generate a parser call for a region list.
 ///
 /// {0}: The name for the region list.
@@ -1437,19 +1454,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
   };
 
   // Resolve each of the result types.
-  if (allResultTypes) {
-    body << "  result.addTypes(allResultTypes);\n";
-  } else {
-    for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
-      body << "  result.addTypes(";
-      emitTypeResolver(resultTypes[i], op.getResultName(i));
-      body << ");\n";
+  if (!infersResultTypes) {
+    if (allResultTypes) {
+      body << "  result.addTypes(allResultTypes);\n";
+    } else {
+      for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+        body << "  result.addTypes(";
+        emitTypeResolver(resultTypes[i], op.getResultName(i));
+        body << ");\n";
+      }
     }
   }
 
   // Early exit if there are no operands.
-  if (op.getNumOperands() == 0)
+  if (op.getNumOperands() == 0) {
+    // Handle return type inference here if there are no operands
+    if (infersResultTypes)
+      body << formatv(inferReturnTypesParserCode, op.getCppClassName());
     return;
+  }
 
   // Handle the case where all operand types are in one group.
   if (allOperandTypes) {
@@ -1532,6 +1555,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
       body << ", " << operand.name << "OperandsLoc";
     body << ", result.operands))\n    return ::mlir::failure();\n";
   }
+
+  // Handle return type inference once all operands have been resolved
+  if (infersResultTypes)
+    body << formatv(inferReturnTypesParserCode, op.getCppClassName());
 }
 
 void OperationFormat::genParserRegionResolution(Operator &op,
@@ -2478,6 +2505,7 @@ class FormatParser {
   // during parsing.
   bool hasAttrDict = false;
   bool hasAllRegions = false, hasAllSuccessors = false;
+  bool canInferResultTypes = false;
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
   llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
@@ -2515,6 +2543,9 @@ LogicalResult FormatParser::parse() {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
       handleTypesMatchConstraint(variableTyResolver, def);
+    } else if (def.getName() == "InferTypeOpInterface" &&
+               !op.allResultTypesKnown()) {
+      canInferResultTypes = true;
     }
   }
 
@@ -2684,6 +2715,14 @@ LogicalResult FormatParser::verifyResults(
   if (fmt.allResultTypes)
     return ::mlir::success();
 
+  // If no result types are specified and we can infer them, infer all result
+  // types
+  if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
+      canInferResultTypes) {
+    fmt.infersResultTypes = true;
+    return ::mlir::success();
+  }
+
   // Check that all of the result types can be inferred.
   auto &buildableTypes = fmt.buildableTypes;
   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {


        


More information about the Mlir-commits mailing list