[Mlir-commits] [mlir] e98a61d - [mlir][tosa] Add verifier check for Concat Op (#136047)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 24 02:07:56 PDT 2025
Author: Tai Ly
Date: 2025-04-24T10:07:53+01:00
New Revision: e98a61dc326c1b564461c0ae4fc693be5113d540
URL: https://github.com/llvm/llvm-project/commit/e98a61dc326c1b564461c0ae4fc693be5113d540
DIFF: https://github.com/llvm/llvm-project/commit/e98a61dc326c1b564461c0ae4fc693be5113d540.diff
LOG: [mlir][tosa] Add verifier check for Concat Op (#136047)
This adds verifier check for Concat Op
to make sure the sum of concatenated axis dimensions is equal to the
output's axis dimension
add tests in verifier.mlir
also moved existing concat verifier checks to verifier.mlir
Signed-off-by: Tai Ly <tai.ly at arm.com>
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/verifier.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c36c1074f5780..751ae785bda6f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1322,6 +1322,25 @@ LogicalResult tosa::ConcatOp::verify() {
<< " on operands 0 and " << operandNum;
}
}
+
+ // ERROR_IF(axis_sum != shape[axis]);
+ int64_t axisSum = 0;
+ for (const auto &input : inputList) {
+ const ShapeAdaptor inputShape(input.getType());
+ if (inputShape.isDynamicDim(axis)) {
+ // make axisSum negative to indicate invalid value
+ axisSum = -1;
+ break;
+ }
+ axisSum += inputShape.getDimSize(axis);
+ }
+ const ShapeAdaptor outputShape(outType);
+ if (axisSum >= 0 && outputShape.hasRank() &&
+ !outputShape.isDynamicDim(axis) &&
+ axisSum != outputShape.getDimSize(axis))
+ return emitOpError("requires sum of axis dimensions of input1 "
+ "equal to output axis dimension, got ")
+ << axisSum << " and " << outputShape.getDimSize(axis);
}
return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..b147c94fde9b0 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
- return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
- // expected-error at +1 {{'tosa.concat' op expect at least one input}}
- %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
- %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
- %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..262e6d4265ea6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,42 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
: (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
return %0 : tensor<1x4x8x19x34xf32>
}
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+ // expected-error at +1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+ // expected-error at +1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
More information about the Mlir-commits
mailing list