[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 27 14:13:03 PST 2025
https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/127923
>From bb978ee15c57659815f1b7dc8b83af0a2b5e3e6c Mon Sep 17 00:00:00 2001
From: Jerry Ge <jerry.ge at arm.com>
Date: Wed, 19 Feb 2025 14:13:43 -0800
Subject: [PATCH] [mlir][tosa] Add more verifiers for the following operators
For ConcatOp, this commit also enhances the verifier by
checking four additional conditions:
- The input list is not empty
- The axis value is within the range [0, rank) of the first input
- All inputs have the same rank
- All dimensions other than the concatenation axis have the same size
For MatMulOp:
- Check the tensor types and element types of inputs a and b
For the following operators, add the verifySameElementTypes check:
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp
Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Luke Hutton <luke.hutton at arm.com>
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 +
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 208 ++++++++++++++++++-
mlir/test/Dialect/Tosa/invalid.mlir | 46 +++-
3 files changed, 249 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..0e5df48fb9d15 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -310,6 +310,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
];
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -344,6 +345,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
];
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1471,6 +1473,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1846,6 +1849,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -2102,6 +2106,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -2135,6 +2141,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..cb198471818e7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -871,6 +871,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+/// Verify tosa.concat:
+///  - there is at least one input;
+///  - every input has the same element type as the output;
+///  - when input1[0] is ranked, the axis lies in [0, rank(input1[0])); and
+///  - every other *ranked* input has the same rank as input1[0] and the same
+///    size on each non-axis dimension where both sizes are static.
+/// Unranked inputs and dynamic dimensions are skipped rather than rejected,
+/// because their shape is not known at verification time. An unranked
+/// operand does not disable the checks between the remaining ranked ones.
+// NOTE(review): the axis diagnostic says "0 < axis" although axis == 0 is
+// accepted, and the non-axis diagnostic prints operand N's size before
+// operand 0's. The wording is kept as-is because invalid.mlir matches it.
+LogicalResult tosa::ConcatOp::verify() {
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  // check that each input has same element type as output
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  // Without a ranked reference operand there is nothing to compare against.
+  if (!firstInputShape.hasRank())
+    return success();
+
+  // Check axis is in expected range
+  if (axis < 0 || axis >= firstInputShape.getRank())
+    return emitOpError("expect axis to be within range 0 < axis < "
+                       "rank(input1[0]), got ")
+           << axis;
+
+  const int64_t firstInputRank = firstInputShape.getRank();
+  for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+    const ShapeAdaptor inputShape(input.getType());
+    // An unranked operand cannot be checked, but it must not stop the
+    // remaining ranked operands from being compared against input1[0].
+    if (!inputShape.hasRank())
+      continue;
+
+    const int64_t inputRank = inputShape.getRank();
+    const size_t operandNum = index + 1;
+
+    // Check that each operand has the same rank
+    if (inputRank != firstInputRank)
+      return emitOpError(
+                 "expect all operands to have the same rank, but got ")
+             << firstInputRank << " vs " << inputRank << " on operands 0 and "
+             << operandNum;
+
+    // Check non-axis dims match; dynamic sizes are compatible with anything.
+    for (int i = 0; i < inputRank; i++) {
+      if (i == axis || firstInputShape.isDynamicDim(i) ||
+          inputShape.isDynamicDim(i))
+        continue;
+      const int64_t inputDim = inputShape.getDimSize(i);
+      const int64_t firstInputDim = firstInputShape.getDimSize(i);
+      if (inputDim != firstInputDim)
+        return emitOpError("expect all operand shapes to have the same sizes "
+                           "on non-axis dimensions, but got ")
+               << inputDim << " vs " << firstInputDim << " at index " << i
+               << " on operands 0 and " << operandNum;
+    }
+  }
+
+  return success();
+}
+
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +985,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
return success();
}
+/// Verify tosa.matmul operand types. Inputs a and b must be shaped types. Their
+/// element types must either both be quant::UniformQuantizedType with the same
+/// storage width, or both be non-quantized and identical.
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    // Mixing a quantized and a non-quantized operand is rejected.
+    if (!aQuantizedEType || !bQuantizedEType)
+      return emitOpError("expect operands to be both quantized or both not "
+                         "quantized, got ")
+             << aElementType << " and " << bElementType;
+
+    // Both operands are quantized: only the storage widths are compared here.
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth)
+      return emitOpError("expect quantized operands to have same widths, got ")
+             << aQuantWidth << " and " << bQuantWidth;
+
+    return success();
+  }
+
+  // Non-quantized element types must match exactly.
+  if (aElementType != bElementType)
+    return emitOpError("expect same element type for inputs a and b, got ")
+           << aElementType << " and " << bElementType;
+
+  return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -968,6 +1084,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
}
LogicalResult tosa::PadOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+
+ if (auto padConst = getPadConst()) {
+ if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ }
+
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1041,6 +1171,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
}
LogicalResult tosa::SliceOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed())
+ return failure();
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
if (!inputType)
return success();
@@ -1048,14 +1182,12 @@ LogicalResult tosa::SliceOp::verify() {
auto startShapeRank =
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
if (inputType.getRank() != startShapeRank)
- return emitOpError(
- "length of start attribute is not equal rank of input shape");
+ return emitOpError("length of start is not equal to rank of input shape");
auto sizeShapeRank =
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
if (inputType.getRank() != sizeShapeRank)
- return emitOpError(
- "length of size attribute is not equal rank of input shape");
+ return emitOpError("length of size is not equal to rank of input shape");
return success();
}
@@ -1260,6 +1392,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
}
LogicalResult tosa::TileOp::verify() {
+ if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
@@ -1341,6 +1478,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
llvm::LogicalResult tosa::ReshapeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
@@ -1528,6 +1670,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
LogicalResult tosa::TransposeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1628,6 +1775,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+/// Verify tosa.gather: the output must have the element type of `values`.
+LogicalResult tosa::GatherOp::verify() {
+  auto valuesType = getValues().getType();
+  auto outputType = getOutput().getType();
+  return verifySameElementTypes(*this, /* inType = */ valuesType,
+                                /* outType = */ outputType);
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -1789,6 +1941,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
return success();
}
+/// Verify tosa.scatter: both values_in and input must have the element type
+/// of values_out. values_in is checked first, so its diagnostic wins.
+LogicalResult tosa::ScatterOp::verify() {
+  auto valuesOutType = getValuesOut().getType();
+
+  if (failed(verifySameElementTypes(*this,
+                                    /* inType = */ getValuesIn().getType(),
+                                    /* outType = */ valuesOutType)))
+    return failure();
+
+  if (failed(verifySameElementTypes(*this,
+                                    /* inType = */ getInput().getType(),
+                                    /* outType = */ valuesOutType)))
+    return failure();
+
+  return success();
+}
+
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2408,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}
+/// Verify tosa.max_pool2d: the output must have the same element type as the
+/// input.
+LogicalResult MaxPool2dOp::verify() {
+  auto inputType = getInput().getType();
+  auto outputType = getOutput().getType();
+  return verifySameElementTypes(*this, /* inType = */ inputType,
+                                /* outType = */ outputType);
+}
+
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2546,6 +2715,10 @@ void IfOp::print(OpAsmPrinter &p) {
}
LogicalResult ReverseOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed())
+ return failure();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
int32_t reverseAxis = getAxis();
@@ -2574,6 +2747,33 @@ LogicalResult ReverseOp::verify() {
return success();
}
+/// Verify tosa.select:
+///  - input2 and input3 must have the output's element type; and
+///  - the predicate input1 must be a shaped type with i1 elements.
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType)
+    return emitOpError("expect shaped tensor for input1, got ")
+           << getInput1().getType();
+
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1))
+    return emitOpError("expect element type of bool for input1, got ")
+           << predicateElementType;
+
+  return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..d0d98af38abc4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -193,8 +193,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
// -----
+// Operand 0 is 2x1 and operand 1 is 2x2, so they differ on non-axis
+// dimension 1 when concatenating along axis 0.
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
- // expected-error at +2 {{failed to infer returned types}}
- // expected-error at +1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+ // expected-error at +1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -202,14 +201,36 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
+// The inputs are f32 but the result is declared i8, so the element-type check
+// rejects the op.
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error at +2 {{failed to infer returned types}}
- // expected-error at +1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+ // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
// -----
+// A concat with no operands is rejected. The function is otherwise
+// well-formed (it has a terminator), so the concat is the only invalid op.
+func.func @test_concat_zero_inputs() {
+ // expected-error at +1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+ return
+}
+
+// -----
+
+// A negative axis is rejected by the concat axis-range check.
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[0]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// axis = 3 is not less than rank(input1[0]) = 2, so it is rejected.
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[0]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
// expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}
%0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
@@ -236,6 +257,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
// -----
+// Operand 0 has rank 3 and operand 1 has rank 2, so the same-rank check fails.
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
+ // expected-error at +1 {{'tosa.concat' op expect all operands to have the same rank, but got 3 vs 2 on operands 0 and 1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
+ return %0 : tensor<2x2x3xf32>
+}
+
+// -----
+
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
// expected-error at +1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
@@ -430,8 +459,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+// The reshape result element type (i32) differs from the input's (f32).
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
%1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- // expected-error at +2 {{failed to infer returned types}}
- // expected-error at +1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+ // expected-error at +1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
%0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
return
}
@@ -521,7 +549,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
+// The result uses f32 like the input, so the new element-type check passes
+// and only the axis-range diagnostic is produced.
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error at +1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
- %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+ %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
return
}
@@ -634,7 +662,7 @@ func.func @test_slice_invalid_start() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // expected-error at +1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+ // expected-error at +1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
return
}
@@ -645,7 +673,7 @@ func.func @test_slice_invalid_size() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
- // expected-error at +1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+ // expected-error at +1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
%3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
return
}
More information about the Mlir-commits
mailing list