[Mlir-commits] [mlir] [mlir][tosa] Add verifier check for Concat Op (PR #136047)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 16 15:39:51 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This adds a verifier check for the Concat op
to make sure that the sum of the inputs' dimensions along the concatenation axis is equal to the output's dimension along that axis.
It also adds tests for the new check in verifier.mlir
and moves the existing concat verifier checks from invalid.mlir to verifier.mlir.
---
Full diff: https://github.com/llvm/llvm-project/pull/136047.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+19)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (-31)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+39)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
<< " on operands 0 and " << operandNum;
}
}
+
+ // ERROR_IF(axis_sum != shape[axis]);
+ int64_t axis_sum = 0;
+ for (const auto &input : inputList) {
+ const ShapeAdaptor inputShape(input.getType());
+ if (inputShape.isDynamicDim(axis)) {
+ // make axis_sum negative to indicate invalid value
+ axis_sum = -1;
+ break;
+ }
+ axis_sum += inputShape.getDimSize(axis);
+ }
+ const ShapeAdaptor outputShape(outType);
+ if (axis_sum >= 0 && outputShape.hasRank() &&
+ !outputShape.isDynamicDim(axis) &&
+ axis_sum != outputShape.getDimSize(axis))
+ return emitOpError("requires sum of axis dimensions of input1 "
+ "equal to output axis dimension, got ")
+ << axis_sum << " and " << outputShape.getDimSize(axis);
}
return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
- return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
- // expected-error@+1 {{'tosa.concat' op expect at least one input}}
- %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
- %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
- %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
%2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
return %2 : tensor<f32>
}
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+ // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+ // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/136047
More information about the Mlir-commits
mailing list