[Mlir-commits] [mlir] [mlir][tosa] Add error_if checks for Transpose (PR #135219)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 10 10:38:19 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This adds the missing ERROR_IF check for the Transpose op (input1 and output
must have the same number of elements). It also moves all of the transpose
op's verifier tests from invalid.mlir to verifier.mlir.


---
Full diff: https://github.com/llvm/llvm-project/pull/135219.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+29-19) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (-112) 
- (added) mlir/test/Dialect/Tosa/verifier.mlir (+126) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5941be8403480..1ba2cda784463 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1964,23 +1964,28 @@ LogicalResult tosa::TransposeOp::verify() {
           .failed()) {
     return failure();
   }
-  TensorType inputType = getInput1().getType();
-  TensorType outputType = getOutput().getType();
+
+  const ShapeAdaptor inputShape(getInput1().getType());
+  const ShapeAdaptor outputShape(getOutput().getType());
+
   const llvm::ArrayRef<int32_t> constantPerms = getPerms();
 
-  if (inputType.hasRank() &&
-      constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+  if (inputShape.hasRank() &&
+      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
     return emitOpError() << "expected perms attribute to have size "
-                         << inputType.getRank() << " (input rank) but got size "
+                         << inputShape.getRank()
+                         << " (input rank) but got size "
                          << constantPerms.size();
-  if (inputType.hasRank() && outputType.hasRank() &&
-      inputType.getRank() != outputType.getRank())
+
+  if (inputShape.hasRank() && outputShape.hasRank() &&
+      inputShape.getRank() != outputShape.getRank())
     return emitOpError()
            << "expected input tensor rank to equal result tensor rank";
-  if (outputType.hasRank() &&
-      constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+
+  if (outputShape.hasRank() &&
+      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
     return emitOpError() << "expected perms attribute to have size "
-                         << outputType.getRank()
+                         << outputShape.getRank()
                          << " (output rank) but got size "
                          << constantPerms.size();
 
@@ -1993,22 +1998,27 @@ LogicalResult tosa::TransposeOp::verify() {
           constantPerms, [](int32_t v) -> int64_t { return v; }))))
     return emitOpError() << "expected valid permutation indices";
 
+  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
+  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
+      inputShape.getNumElements() != outputShape.getNumElements())
+    return emitOpError() << "expected input1 and output to have same numbers "
+                            "of elements, got "
+                         << inputShape.getNumElements() << " and "
+                         << outputShape.getNumElements();
+
   // Verify that the types of the input and output tensors are properly
   // permuted.
-  if (inputType.hasRank() && outputType.hasRank()) {
-    assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
-           inputType.getRank() == outputType.getRank());
-
-    for (auto i = 0; i < outputType.getRank(); i++) {
-      if (inputType.isDynamicDim(constantPerms[i]) ||
-          outputType.isDynamicDim(i))
+  if (inputShape.hasRank() && outputShape.hasRank()) {
+    for (auto i = 0; i < outputShape.getRank(); i++) {
+      if (inputShape.isDynamicDim(constantPerms[i]) ||
+          outputShape.isDynamicDim(i))
         continue;
 
-      if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
         return emitOpError()
                << "expected output tensor dim " << i << " to match "
                << "input dim " << constantPerms[i] << " with value of "
-               << inputType.getDimSize(constantPerms[i]);
+               << inputShape.getDimSize(constantPerms[i]);
     }
   }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 55a9fcb15bbc7..3310919d406a2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -368,79 +368,6 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor
 
 // -----
 
-func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
-  // expected-error at +1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
-  %0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
-  return %0 : tensor<3x13x21x1xf32>
-}
-
-// -----
-
-func.func @test_transpose_rank0_perms() {
-  %14 = tensor.empty() : tensor<5x27xi64>
-  // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
-  %72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
-  return
-}
-
-// -----
-
-func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
-  // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
-  %0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
-  return %0 : tensor<3x13x21xf32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
-  return %1 : tensor<*xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
-  return %1 : tensor<*xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
-  // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
-  return %1 : tensor<3x4xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
-  // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
-  return %1 : tensor<3x4xi32>
-}
-
-// -----
-
-func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
-  // expected-error at +1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// -----
-
 func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
   // expected-error at +2 {{failed to infer returned types}}
   // expected-error at +1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
@@ -783,37 +710,6 @@ func.func @test_tile_io_rank_mismatch() {
   return
 }
 
-// -----
-
-// CHECK-LABEL: @test_invalid_constant_permutation
-func.func @test_invalid_constant_permutation() {
-  %0 = tensor.empty() : tensor<3x4x5xi32>
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_rank_size_constant_permutation
-func.func @test_rank_size_constant_permutation() {
-  %0 = arith.constant 6 : index
-  %2 = tensor.empty(%0) : tensor<?x27xi64>
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_large_constant_permutation
-func.func @test_large_constant_permutation() {
-  %0 = arith.constant 6 : index
-  %2 = tensor.empty(%0) : tensor<?x27xi64>
-  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
-  %3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
-  return
-}
 
 // -----
 
@@ -2061,14 +1957,6 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
 
 // -----
 
-func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
-  // expected-error at +1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
-  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
-  return %1 : tensor<f32>
-}
-
-// -----
-
 // CHECK-LABEL: test_add_i1
 func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
   // expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
new file mode 100644
index 0000000000000..c49cbecd25c78
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -0,0 +1,126 @@
+//--------------------------------------------------------------------------------------------------
+// Test expected errors generated by verifier checks.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
+  // expected-error at +1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
+  %0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
+  return %0 : tensor<3x13x21x1xf32>
+}
+
+// -----
+
+func.func @test_transpose_rank0_perms() {
+  %14 = tensor.empty() : tensor<5x27xi64>
+  // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
+  %72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
+  return
+}
+
+// -----
+
+func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+  // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
+  %0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_num_elements(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
+  // expected-error at +1 {{'tosa.transpose' op expected input1 and output to have same numbers of elements, got 6 and 12}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
+  return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
+  // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x2xi32>
+  return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
+  // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
+  return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+  // expected-error at +1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_invalid_constant_permutation
+func.func @test_invalid_constant_permutation() {
+  %0 = tensor.empty() : tensor<3x4x5xi32>
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_rank_size_constant_permutation
+func.func @test_rank_size_constant_permutation() {
+  %0 = arith.constant 6 : index
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_large_constant_permutation
+func.func @test_large_constant_permutation() {
+  %0 = arith.constant 6 : index
+  %2 = tensor.empty(%0) : tensor<?x27xi64>
+  // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+  %3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
+  return
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+  // expected-error at +1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/135219


More information about the Mlir-commits mailing list