[Mlir-commits] [mlir] e98a61d - [mlir][tosa] Add verifier check for Concat Op (#136047)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 24 02:07:56 PDT 2025


Author: Tai Ly
Date: 2025-04-24T10:07:53+01:00
New Revision: e98a61dc326c1b564461c0ae4fc693be5113d540

URL: https://github.com/llvm/llvm-project/commit/e98a61dc326c1b564461c0ae4fc693be5113d540
DIFF: https://github.com/llvm/llvm-project/commit/e98a61dc326c1b564461c0ae4fc693be5113d540.diff

LOG: [mlir][tosa] Add verifier check for Concat Op (#136047)

This adds a verifier check for the Concat op
to make sure that the sum of the input dimensions along the concatenation
axis is equal to the output's dimension along that axis.

Added tests in verifier.mlir, and moved the existing concat verifier
checks from invalid.mlir to verifier.mlir.

Signed-off-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/verifier.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c36c1074f5780..751ae785bda6f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1322,6 +1322,25 @@ LogicalResult tosa::ConcatOp::verify() {
                  << " on operands 0 and " << operandNum;
       }
     }
+
+    // ERROR_IF(axis_sum != shape[axis]);
+    int64_t axisSum = 0;
+    for (const auto &input : inputList) {
+      const ShapeAdaptor inputShape(input.getType());
+      if (inputShape.isDynamicDim(axis)) {
+        // make axisSum negative to indicate invalid value
+        axisSum = -1;
+        break;
+      }
+      axisSum += inputShape.getDimSize(axis);
+    }
+    const ShapeAdaptor outputShape(outType);
+    if (axisSum >= 0 && outputShape.hasRank() &&
+        !outputShape.isDynamicDim(axis) &&
+        axisSum != outputShape.getDimSize(axis))
+      return emitOpError("requires sum of axis dimensions of input1 "
+                         "equal to output axis dimension, got ")
+             << axisSum << " and " << outputShape.getDimSize(axis);
   }
 
   return success();

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..b147c94fde9b0 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 
 // -----
 
-func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
-  return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-func.func @test_concat_zero_inputs() {
-  // expected-error at +1 {{'tosa.concat' op expect at least one input}}
-  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
-  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
-  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error at +1 {{'tosa.pad' op shape operand is not compile time resolvable}}

diff  --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..262e6d4265ea6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,42 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
     : (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
   return %0 : tensor<1x4x8x19x34xf32>
 }
+
+// -----
+
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+  // expected-error at +1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+func.func @test_concat_zero_inputs() {
+  // expected-error at +1 {{'tosa.concat' op expect at least one input}}
+  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error at +1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
+  // expected-error at +1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}


        


More information about the Mlir-commits mailing list