[Mlir-commits] [mlir] [mlir][tosa] Add additional input output dtype verifiers for the foll… (PR #127923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 13:26:12 PST 2025
https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127923
>From 2fe9dd3ba554301f7920f7f2bc6e8274b95e84fa Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Wed, 19 Feb 2025 14:13:43 -0800
Subject: [PATCH] [mlir][tosa] Add additional input output dtype verifiers for
the following operators
- CastOp
- ConcatOp
- MatMulOp
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp
For ConcatOp specifically, this commit also enhances the verifier by
checking four additional conditions:
- The input list is not empty
- The axis value is within the valid range [0, rank) of the first input
- All inputs have the same rank
- All dimensions other than the concatenation axis have the same size
Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Luke Hutton <luke.hutton at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 9 +
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 354 ++++++++++++++++++-
mlir/test/Dialect/Tosa/invalid.mlir | 23 +-
3 files changed, 373 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..007f98feb99ec 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -282,6 +282,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
];
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1; // enables MatMulOp::verify() in TosaOps.cpp
}
//===----------------------------------------------------------------------===//
@@ -316,6 +317,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
];
let hasCanonicalizer = 1;
+ let hasVerifier = 1; // enables MaxPool2dOp::verify() in TosaOps.cpp
}
//===----------------------------------------------------------------------===//
@@ -1421,6 +1423,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1; // enables tosa::SelectOp::verify() in TosaOps.cpp
let assemblyFormat = [{
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1796,6 +1799,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1; // enables tosa::ConcatOp::verify() in TosaOps.cpp
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -2052,6 +2056,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
+
+ let hasVerifier = 1; // enables tosa::GatherOp::verify() in TosaOps.cpp
}
//===----------------------------------------------------------------------===//
@@ -2079,6 +2085,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
+
+ let hasVerifier = 1; // enables tosa::ScatterOp::verify() in TosaOps.cpp
}
//===----------------------------------------------------------------------===//
@@ -2175,6 +2183,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasVerifier = 1; // enables tosa::CastOp::verify() in TosaOps.cpp
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e9c33e1b1bf10..b499ed41c2f5d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -472,6 +472,107 @@ LogicalResult tosa::AvgPool2dOp::verify() {
return emitOpError("input/output element types are incompatible.");
}
+/// Checks that the input/output element types (using the storage type for
+/// quantized types) form a cast pair allowed by the TOSA spec, or one of the
+/// non-spec extensions accepted near the end of this function.
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int16
+  if (inputETy.isInteger(16)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int32
+  if (inputETy.isInteger(32)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: bf16 or fp16
+  if (inputETy.isBF16() || inputETy.isF16()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: f8e4m3 or f8e5m2
+  if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+      llvm::isa<Float8E5M2Type>(inputETy)) {
+    if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: fp32
+  if (inputETy.isF32()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+        outputETy.isBF16()) {
+      return success();
+    }
+  }
+
+  // following are outside of TOSA Spec
+
+  // allow casting to same type, for quantization/dequantization
+  if (inputETy == outputETy) {
+    return success();
+  }
+
+  // allow casting f32 to bool, for tosa_to_linalg testing
+  if (inputETy.isF32() && outputETy.isInteger(1)) {
+    return success();
+  }
+
+  // special case for I64
+  if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from I64
+    return success();
+  }
+
+  // special case for fp64
+  if (inputETy.isF64() || outputETy.isF64()) {
+    // be forgiving of casting to and from F64
+    return success();
+  }
+
+  return emitOpError("input/output element types are incompatible: ")
+         << inputETy << " and " << outputETy;
+}
+
LogicalResult tosa::ClampOp::verify() {
mlir::Type inputETy =
llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -852,6 +950,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+/// Verifies tosa.concat: every input has the output's element type, there is
+/// at least one input, `axis` lies in [0, rank(input1[0])) when that rank is
+/// known, and, when all inputs are ranked, they agree in rank and static
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has same element type as output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  if (firstInputShape.hasRank()) {
+    // Check axis is in expected range. Axis 0 is valid, so the diagnostic
+    // states an inclusive lower bound.
+    if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 <= axis < "
+                         "rank(input1[0]), got ")
+             << axis;
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstInputShape.getRank();
+
+    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match; dynamic dims on either side are skipped.
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstInputShape.getDimSize(i);
+        if (i == axis || firstInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -901,6 +1064,92 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
return success();
}
+/// Verifies the element types of tosa.matmul: inputs a and b must either both
+/// be uniform-quantized with the same 8- or 16-bit storage width, or share one
+/// non-quantized element type; the result element type must then be the one
+/// required for that input type by the checks below.
+LogicalResult tosa::MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+  auto resultEType =
+      llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+  // Must be shaped tensor types
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType)
+      return emitOpError(
+                 "expect operands to be both quantized or both not quantized, got ")
+             << aElementType << " and " << bElementType;
+
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth)
+      return emitOpError("expect quantized operands to have same widths, got ")
+             << aQuantWidth << " and " << bQuantWidth;
+
+    if (aQuantWidth != 8 && aQuantWidth != 16)
+      return emitOpError(
+                 "only support quantized types with width of 8 or 16, got ")
+             << aQuantWidth;
+
+    // check result types: 8-bit inputs require i32, 16-bit inputs require i48
+    if (aQuantWidth == 8 && !resultEType.isInteger(32))
+      return emitOpError("expect result element type to be i32, got ")
+             << resultEType;
+    if (aQuantWidth == 16 && !resultEType.isInteger(48))
+      return emitOpError("expect result element type to be i48, got ")
+             << resultEType;
+
+    return success();
+  }
+
+  // non-quantized element types
+
+  if (aElementType != bElementType)
+    return emitOpError("expect same element type for inputs a and b, got ")
+           << aElementType << " and " << bElementType;
+
+  if ((llvm::isa<Float8E5M2Type>(aElementType) ||
+       llvm::isa<Float8E4M3FNType>(aElementType)) &&
+      !resultEType.isF16())
+    return emitOpError("expect result element type to be f16, got ")
+           << resultEType;
+
+  if (aElementType.isInteger(8) && !resultEType.isInteger(32))
+    return emitOpError("expect result element type to be i32, got ")
+           << resultEType;
+  if (aElementType.isInteger(16) && !resultEType.isInteger(48))
+    return emitOpError("expect result element type to be i48, got ")
+           << resultEType;
+  if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32()))
+    return emitOpError("expect result element type to be f16 or f32, got ")
+           << resultEType;
+  if (aElementType.isBF16() && !resultEType.isF32())
+    return emitOpError("expect result element type to be f32, got ")
+           << resultEType;
+  if (aElementType.isF32() && !resultEType.isF32())
+    return emitOpError("expect result element type to be f32, got ")
+           << resultEType;
+
+  return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -949,6 +1213,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
}
LogicalResult tosa::PadOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) { // input1 and output must have the same element type
+ return failure();
+ }
+ if (auto padConst = getPadConst()) { // optional; when present it must match too
+ if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ }
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1022,6 +1298,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
}
LogicalResult tosa::SliceOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) // input1 and output must have the same element type
+ return failure();
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
if (!inputType)
return success();
@@ -1029,14 +1309,12 @@ LogicalResult tosa::SliceOp::verify() {
auto startShapeRank =
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
if (inputType.getRank() != startShapeRank)
- return emitOpError(
- "length of start attribute is not equal rank of input shape");
+ return emitOpError("length of start is not equal to rank of input shape");
auto sizeShapeRank =
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
if (inputType.getRank() != sizeShapeRank)
- return emitOpError(
- "length of size attribute is not equal rank of input shape");
+ return emitOpError("length of size is not equal to rank of input shape");
return success();
}
@@ -1241,6 +1519,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
}
LogicalResult tosa::TileOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) { // input1 and output must have the same element type
+ return failure();
+ }
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
@@ -1322,6 +1605,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
llvm::LogicalResult tosa::ReshapeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) { // input1 and output must have the same element type
+ return failure();
+ }
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
@@ -1466,6 +1754,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
LogicalResult tosa::TransposeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) { // input1 and output must have the same element type
+ return failure();
+ }
TensorType inputType = getInput1().getType();
TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
@@ -1581,6 +1874,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::GatherOp::verify() { // values and output must share an element type
+ return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -1749,6 +2047,20 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ScatterOp::verify() {
+  // values_in and input must both carry the element type of values_out.
+  auto valuesOutType = getValuesOut().getType();
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ valuesOutType)
+          .failed())
+    return failure();
+  if (verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ valuesOutType)
+          .failed())
+    return failure();
+  return success();
+}
+
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2113,6 +2423,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}
+LogicalResult MaxPool2dOp::verify() { // input and output must share an element type
+ return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2415,6 +2730,10 @@ void IfOp::print(OpAsmPrinter &p) {
}
LogicalResult ReverseOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) // input1 and output must have the same element type
+ return failure();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
int32_t reverseAxis = getAxis();
@@ -2443,6 +2762,33 @@ LogicalResult ReverseOp::verify() {
return success();
}
+/// Verifies tosa.select: input2 and input3 must have the output's element
+/// type, and the predicate input1 must be a shaped value whose element type
+/// is i1 (bool).
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType)
+    return emitOpError("expect shaped tensor for input1, got ")
+           << getInput1().getType();
+
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1))
+    return emitOpError("expect element type of bool for input1, got ")
+           << predicateElementType;
+
+  return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..31a19df915f98 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -164,8 +164,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+ // expected-error@+1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -173,8 +172,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
@@ -207,6 +205,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
// -----
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect all operands to have the same rank, but got 3 vs 2 on operands 0 and 1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
+ return %0 : tensor<2x2x3xf32>
+}
+
+// -----
+
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
@@ -424,8 +430,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
%1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+ // expected-error@+1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
return
}
@@ -515,7 +520,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
- %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+ %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
return
}
@@ -624,7 +629,7 @@ func.func @test_slice_invalid_start() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error@+1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+ // expected-error@+1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
return
}
@@ -635,7 +640,7 @@ func.func @test_slice_invalid_size() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
- // expected-error@+1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+ // expected-error@+1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
return
}
More information about the Mlir-commits
mailing list