[Mlir-commits] [mlir] [mlir][tosa] Add verifier check for Concat Op (PR #136047)
Tai Ly
llvmlistbot at llvm.org
Wed Apr 16 15:39:14 PDT 2025
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/136047
This adds a verifier check for the Concat op
to make sure the sum of the concatenated axis dimensions is equal to the output's axis dimension.
Adds tests in verifier.mlir.
Also moves the existing concat verifier checks to verifier.mlir.
>From 695b02a5f5f40b4b634ef21fbab60bd829ab5fda Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 8 Apr 2025 16:30:57 +0000
Subject: [PATCH] [mlir][tosa] Add verifier check for Concat Op
This adds a verifier check for the Concat op
to make sure the sum of the concatenated axis dimensions
is equal to the output's axis dimension.
Adds tests in verifier.mlir.
Also moves the existing concat verifier checks to verifier.mlir.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I53e41ca3c1f4ee48997c510fee2c16ed912dfaa0
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 19 ++++++++++++++
mlir/test/Dialect/Tosa/invalid.mlir | 31 ----------------------
mlir/test/Dialect/Tosa/verifier.mlir | 39 ++++++++++++++++++++++++++++
3 files changed, 58 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..d9e77dd3f3770 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
<< " on operands 0 and " << operandNum;
}
}
+
+ // ERROR_IF(axis_sum != shape[axis]);
+ int64_t axis_sum = 0;
+ for (const auto &input : inputList) {
+ const ShapeAdaptor inputShape(input.getType());
+ if (inputShape.isDynamicDim(axis)) {
+ // make axis_sum negative to indicate invalid value
+ axis_sum = -1;
+ break;
+ }
+ axis_sum += inputShape.getDimSize(axis);
+ }
+ const ShapeAdaptor outputShape(outType);
+ if (axis_sum >= 0 && outputShape.hasRank() &&
+ !outputShape.isDynamicDim(axis) &&
+ axis_sum != outputShape.getDimSize(axis))
+ return emitOpError("requires sum of axis dimensions of input1 "
+ "equal to output axis dimension, got ")
+ << axis_sum << " and " << outputShape.getDimSize(axis);
}
return success();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index fc98aa95ed5b3..1ff73bee3923d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
- return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
- // expected-error@+1 {{'tosa.concat' op expect at least one input}}
- %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
- %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
- %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
-}
-
-// -----
-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index efdd26a9346fb..e6310fee22479 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
%2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
return %2 : tensor<f32>
}
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+ return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+ // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+ %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+ %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+ // expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
More information about the Mlir-commits
mailing list