[Mlir-commits] [mlir] [mlir][tosa] Add additional input/output dtype verifiers for the foll… (PR #127923)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 11:20:35 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127923

>From 5c2f63265e636702d6295bf0d31a5ea6a7284f46 Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Wed, 19 Feb 2025 14:13:43 -0800
Subject: [PATCH] [mlir][tosa] Add additional input/output dtype verifiers for
 the following operators

- CastOp
- ConcatOp
- MatMulOp
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp
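
As an illustrative example (not part of the patch itself), the new CastOp
verifier rejects element-type pairs that the TOSA specification does not
define, such as a bool-to-float cast:

  // Rejected: i1 -> f32 is not a spec-defined cast; the verifier reports
  // 'tosa.cast' op input/output element types are incompatible: 'i1' and 'f32'
  %0 = tosa.cast %arg0 : (tensor<4xi1>) -> tensor<4xf32>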

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Luke Hutton <luke.hutton at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |   9 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 354 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/invalid.mlir          |  23 +-
 3 files changed, 373 insertions(+), 13 deletions(-)
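
Note: with these verifiers, the input/output element-type check is reported
by the op itself rather than by return-type inference, so several expected
errors in invalid.mlir are updated below. A hypothetical case, assuming the
shared verifySameElementTypes diagnostic:

  // 'tosa.reverse' op expect input and output to have same element type,
  // got 'f32' and 'i32'
  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xi32>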

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7cdf79f4dc59d..a9f6f56532aeb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -248,6 +248,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   );
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -277,6 +278,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1200,6 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   );
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1528,6 +1531,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1750,6 +1754,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
   let results = (outs
     Tosa_Tensor3D:$output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1772,6 +1778,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
   let results = (outs
     Tosa_Tensor3D:$values_out
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1860,6 +1868,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d21e218308df7..2a8a6153d58d8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -469,6 +469,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   return emitOpError("input/output element types are incompatible.");
 }
 
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int16
+  if (inputETy.isInteger(16)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int32
+  if (inputETy.isInteger(32)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: bf16 or fp16
+  if (inputETy.isBF16() || inputETy.isF16()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: f8e4m3 or f8e5m2
+  if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+      llvm::isa<Float8E5M2Type>(inputETy)) {
+    if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: fp32
+  if (inputETy.isF32()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+        outputETy.isBF16()) {
+      return success();
+    }
+  }
+
+  // the following cases are outside of the TOSA spec
+
+  // allow casting to the same type, for quantization/dequantization
+  if (inputETy == outputETy) {
+    return success();
+  }
+
+  // allow casting float to bool, for tosa_to_linalg testing
+  if (inputETy.isF32() && outputETy.isInteger(1)) {
+    return success();
+  }
+
+  // special case for i64
+  if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from i64
+    return success();
+  }
+
+  // special case for fp64
+  if (inputETy.isF64() || outputETy.isF64()) {
+    // be forgiving of casting to and from F64
+    return success();
+  }
+
+  return emitOpError("input/output element types are incompatible: ")
+         << inputETy << " and " << outputETy;
+}
+
 LogicalResult tosa::ClampOp::verify() {
   mlir::Type inputETy =
       llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -849,6 +947,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has the same element type as the output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  for (auto input : inputList) {
+    if (verifySameElementTypes(*this, /* inType = */ input.getType(), outType)
+            .failed()) {
+      return failure();
+    }
+  }
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  if (firstInputShape.hasRank()) {
+    // Check axis is in expected range
+    if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 <= axis < "
+                         "rank(input1[0]), got ")
+             << axis;
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstInputShape.getRank();
+
+    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstInputShape.getDimSize(i);
+        if (i == axis || firstInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -898,6 +1061,107 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+  auto resultEType =
+      llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+  // Must be shaped tensor types
+  if (!aType) {
+    emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+    return failure();
+  }
+  if (!bType) {
+    emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+    return failure();
+  }
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      emitOpError(
+          "expect operands to be both quantized or both not quantized, got ")
+          << aElementType << " and " << bElementType;
+      return failure();
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      emitOpError("expect quantized operands to have the same width, got ")
+          << aQuantWidth << " and " << bQuantWidth;
+      return failure();
+    }
+
+    if (aQuantWidth != 8 && aQuantWidth != 16) {
+      emitOpError("only support quantized types with width of 8 or 16, got ")
+          << aQuantWidth;
+      return failure();
+    }
+
+    // check result type: 8-bit quantized operands must produce an i32 result
+    if (aQuantWidth == 8 && !resultEType.isInteger(32)) {
+      emitOpError("expect result element type to be i32, got ") << resultEType;
+      return failure();
+    }
+
+    // check result type: 16-bit quantized operands must produce an i48 result
+    if (aQuantWidth == 16 && !resultEType.isInteger(48)) {
+      emitOpError("expect result element type to be i48, got ") << resultEType;
+      return failure();
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+
+  if (aElementType != bElementType) {
+    emitOpError("expect same element type for inputs a and b, got ")
+        << aElementType << " and " << bElementType;
+    return failure();
+  }
+  if (llvm::isa<Float8E5M2Type>(aElementType) ||
+      llvm::isa<Float8E4M3FNType>(aElementType)) {
+    if (!resultEType.isF16()) {
+      emitOpError("expect result element type to be f16, got ") << resultEType;
+      return failure();
+    }
+  }
+
+  if (aElementType.isInteger(8) && !resultEType.isInteger(32)) {
+    emitOpError("expect result element type to be i32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isInteger(16) && !resultEType.isInteger(48)) {
+    emitOpError("expect result element type to be i48, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32())) {
+    emitOpError("expect result element type to be f16 or f32, got ")
+        << resultEType;
+    return failure();
+  }
+  if (aElementType.isBF16() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF32() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -946,6 +1210,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1019,6 +1295,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
@@ -1026,14 +1306,12 @@ LogicalResult tosa::SliceOp::verify() {
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1238,6 +1516,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1319,6 +1602,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
@@ -1463,6 +1751,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType permType = getPerms().getType();
   TensorType outputType = getOutput().getType();
@@ -1578,6 +1871,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -1746,6 +2044,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2110,6 +2420,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2412,6 +2727,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2440,6 +2759,33 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have the same element type as the output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has an element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    emitOpError("expect shaped tensor for input1, got ")
+        << getInput1().getType();
+    return failure();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    emitOpError("expect element type of bool for input1, got ")
+        << predicateElementType;
+    return failure();
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1307da88d1e64..4b3525aa005bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -165,8 +165,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  // expected-error at +1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -174,8 +173,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 // -----
 
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
@@ -208,6 +206,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
 
 // -----
 
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
+  // expected-error at +1 {{'tosa.concat' op expect all operands to have the same rank, but got 3 vs 2 on operands 0 and 1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
+  return %0 : tensor<2x2x3xf32>
+}
+
+// -----
+
 func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
@@ -425,8 +431,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
 
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  // expected-error at +2 {{failed to infer returned types}}
-  // expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  // expected-error at +1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
   %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
   return
 }
@@ -516,7 +521,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
 
 func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error at +1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
-  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
   return
 }
 
@@ -625,7 +630,7 @@ func.func @test_slice_invalid_start() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error at +1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+  // expected-error at +1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
   return
 }
@@ -636,7 +641,7 @@ func.func @test_slice_invalid_size() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // expected-error at +1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+  // expected-error at +1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }

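For reference, the i64/f64 special cases in CastOp::verify deliberately stay
lenient; a hypothetical cast such as the following still verifies even though
it is outside the TOSA spec:

  // accepted by the verifier's i64 carve-out
  %0 = tosa.cast %arg0 : (tensor<4xi64>) -> tensor<4xi32>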

