[Mlir-commits] [mlir] c8598fa - [mlir] Add refineReturnTypes to InferTypeOpInterface

Jacques Pienaar llvmlistbot at llvm.org
Mon Jul 18 22:18:59 PDT 2022


Author: Jacques Pienaar
Date: 2022-07-18T22:18:52-07:00
New Revision: c8598fa22fdf587c00df6966800c8c6d3c62185d

URL: https://github.com/llvm/llvm-project/commit/c8598fa22fdf587c00df6966800c8c6d3c62185d
DIFF: https://github.com/llvm/llvm-project/commit/c8598fa22fdf587c00df6966800c8c6d3c62185d.diff

LOG: [mlir] Add refineReturnTypes to InferTypeOpInterface

The refineReturnTypes method takes the same parameters as inferReturnTypes
but is additionally passed the return types of the op, if known, which can
be used during refinement passes or for more op-specific error reporting.
Currently the error reporting on failure is generic and doesn't allow
specializing the returned result based on the failure. With this change,
what would previously have been a separate trait with specialized
verification can just be handled as part of inference rather than
duplicated.

refineReturnTypes behaves like inferReturnTypes if no result types are passed in,
while the current verification is recast as the default implementation of
refineReturnTypes, which calls inferReturnTypes (so the default type
verification now goes through refineReturnTypes and allows for more op-specific
inference mismatch errors).

Differential Revision: https://reviews.llvm.org/D129955

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/return-types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 09bbba68be52..e1d9e5e120e1 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -47,6 +47,70 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
                     "::mlir::RegionRange":$regions,
                     "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
     >,
+    StaticInterfaceMethod<
+      /*desc=*/[{Refine the return types that an op would generate.
+
+      This method computes the return types as `inferReturnTypes` does but
+      additionally takes the existing result types as input. The existing
+      result types can be checked as part of inference to provide more
+      op-specific error messages, or merged with the inferred types to carry
+      over additional information (e.g., attributes). It is called during
+      verification for ops implementing this interface; the default
+      implementation reports a mismatch by printing both the current and the
+      inferred types.
+
+      The operands and attributes correspond to those with which an Operation
+      would be created (e.g., as used in Operation::create) and the regions of
+      the op. The method takes an optional location which, if set, will be used
+      to report errors on.
+
+      The return types may be empty, in which case this method behaves like
+      `inferReturnTypes` and populates them. Individual elements may also be
+      null; such elements are filled in with the inferred type but are not
+      verified.
+
+      Be aware that this method is supposed to be called with valid arguments,
+      e.g., operands are verified, or it may result in an undefined behavior.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"refineReturnTypes",
+      /*args=*/(ins "::mlir::MLIRContext *":$context,
+                    "::llvm::Optional<::mlir::Location>":$location,
+                    "::mlir::ValueRange":$operands,
+                    "::mlir::DictionaryAttr":$attributes,
+                    "::mlir::RegionRange":$regions,
+                    "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+          llvm::SmallVector<Type, 4> inferredReturnTypes;
+          if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
+                                                  attributes, regions,
+                                                  inferredReturnTypes)))
+            return failure();
+          // No result types were supplied: behave like inferReturnTypes.
+          if (returnTypes.empty()) {
+            returnTypes.append(inferredReturnTypes.begin(),
+                               inferredReturnTypes.end());
+            return success();
+          }
+          // Null entries are requests to fill in a type without verifying it;
+          // substitute the inferred type so the comparison below accepts it.
+          if (returnTypes.size() == inferredReturnTypes.size()) {
+            for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+              if (!returnTypes[i])
+                returnTypes[i] = inferredReturnTypes[i];
+          }
+          if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
+                                                   returnTypes)) {
+            return emitOptionalError(
+                location, "'", ConcreteOp::getOperationName(),
+                "' op inferred type(s) ", inferredReturnTypes,
+                " are incompatible with return type(s) of operation ",
+                returnTypes);
+          }
+          return success();
+      }]
+    >,
     StaticInterfaceMethod<
       /*desc=*/"Returns whether two array of types are compatible result types"
                " for an op.",

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8f29682636f7..237aa0dedc51 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -204,17 +204,11 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
 }
 
 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
-  SmallVector<Type, 4> inferredReturnTypes;
+  // Seed refinement with the op's current result types so that
+  // refineReturnTypes verifies them (with op-specific diagnostics).
+  SmallVector<Type, 4> resultTypes(op->getResultTypes());
   auto retTypeFn = cast<InferTypeOpInterface>(op);
-  if (failed(retTypeFn.inferReturnTypes(
-          op->getContext(), op->getLoc(), op->getOperands(),
-          op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
-    return failure();
-  if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
-                                         op->getResultTypes()))
-    return op->emitOpError("inferred type(s) ")
-           << inferredReturnTypes
-           << " are incompatible with return type(s) of operation "
-           << op->getResultTypes();
-  return success();
+  return retTypeFn.refineReturnTypes(
+      op->getContext(), op->getLoc(), op->getOperands(),
+      op->getAttrDictionary(), op->getRegions(), resultTypes);
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index efc95e9164dd..2e27663e2cd6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1128,6 +1128,44 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
   return success();
 }
 
+// TODO: It should be possible to define only one of inferReturnTypes or
+// refineReturnTypes, but currently only refineReturnTypes may be omitted.
+// This op therefore implements inferReturnTypes by delegating to
+// refineReturnTypes with an empty list of result types.
+LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  // An empty list asks refineReturnTypes to infer rather than verify.
+  returnTypes.clear();
+  return OpWithRefineTypeInterfaceOp::refineReturnTypes(
+      context, location, operands, attributes, regions, returnTypes);
+}
+
+// Both operands must share a type, which becomes the single result type. A
+// result type supplied by the caller (non-null) must already match it.
+LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
+    MLIRContext *, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  Type lhsType = operands[0].getType();
+  Type rhsType = operands[1].getType();
+  if (lhsType != rhsType)
+    return emitOptionalError(location, "operand type mismatch ", lhsType,
+                             " vs ", rhsType);
+
+  // No result types supplied: add a null slot to be filled in below.
+  // TODO: Add helper to make this more concise to write.
+  if (returnTypes.empty())
+    returnTypes.resize(1, nullptr);
+  Type &resultType = returnTypes.front();
+  if (resultType && resultType != lhsType)
+    return emitOptionalError(location,
+                             "required first operand and result to match");
+  resultType = lhsType;
+  return success();
+}
+
 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
     DictionaryAttr attributes, RegionRange regions,

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b0a0cf3807fc..08f62b00e0e0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -645,8 +645,14 @@ def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
 }
 
 def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let arguments = (ins AnyTensor, AnyTensor);
+  let results = (outs AnyTensor);
+}
+
+def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
     DeclareOpInterfaceMethods<InferTypeOpInterface,
-        ["inferReturnTypeComponents"]>]> {
+        ["refineReturnTypes"]>]> {  // Override the default refineReturnTypes.
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
 }

diff  --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index e2dea26c148d..cf859fe9e8f0 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -39,6 +39,14 @@ func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : ten
 
 // -----
 
+// Checks that the op-specific refineReturnTypes diagnostic is reported.
+func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
+  // expected-error@+1 {{required first operand and result to match}}
+  %bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
 // CHECK-LABEL: testReifyFunctions
 func.func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
   // expected-remark at +1 {{arith.constant 10}}


        


More information about the Mlir-commits mailing list