[Mlir-commits] [mlir] [mlir][tosa] Add error_if checks for Transpose (PR #135219)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 10 10:38:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This adds the missing ERROR_IF checks to the Transpose op verifier.
It also moves all of the Transpose op's verifier tests from
invalid.mlir to verifier.mlir.
---
Full diff: https://github.com/llvm/llvm-project/pull/135219.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+29-19)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (-112)
- (added) mlir/test/Dialect/Tosa/verifier.mlir (+126)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5941be8403480..1ba2cda784463 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1964,23 +1964,28 @@ LogicalResult tosa::TransposeOp::verify() {
.failed()) {
return failure();
}
- TensorType inputType = getInput1().getType();
- TensorType outputType = getOutput().getType();
+
+ const ShapeAdaptor inputShape(getInput1().getType());
+ const ShapeAdaptor outputShape(getOutput().getType());
+
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
- if (inputType.hasRank() &&
- constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+ if (inputShape.hasRank() &&
+ constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
return emitOpError() << "expected perms attribute to have size "
- << inputType.getRank() << " (input rank) but got size "
+ << inputShape.getRank()
+ << " (input rank) but got size "
<< constantPerms.size();
- if (inputType.hasRank() && outputType.hasRank() &&
- inputType.getRank() != outputType.getRank())
+
+ if (inputShape.hasRank() && outputShape.hasRank() &&
+ inputShape.getRank() != outputShape.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
- if (outputType.hasRank() &&
- constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+
+ if (outputShape.hasRank() &&
+ constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
return emitOpError() << "expected perms attribute to have size "
- << outputType.getRank()
+ << outputShape.getRank()
<< " (output rank) but got size "
<< constantPerms.size();
@@ -1993,22 +1998,27 @@ LogicalResult tosa::TransposeOp::verify() {
constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation indices";
+ // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
+ if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
+ inputShape.getNumElements() != outputShape.getNumElements())
+ return emitOpError() << "expected input1 and output to have same numbers "
+ "of elements, got "
+ << inputShape.getNumElements() << " and "
+ << outputShape.getNumElements();
+
// Verify that the types of the input and output tensors are properly
// permuted.
- if (inputType.hasRank() && outputType.hasRank()) {
- assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
- inputType.getRank() == outputType.getRank());
-
- for (auto i = 0; i < outputType.getRank(); i++) {
- if (inputType.isDynamicDim(constantPerms[i]) ||
- outputType.isDynamicDim(i))
+ if (inputShape.hasRank() && outputShape.hasRank()) {
+ for (auto i = 0; i < outputShape.getRank(); i++) {
+ if (inputShape.isDynamicDim(constantPerms[i]) ||
+ outputShape.isDynamicDim(i))
continue;
- if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+ if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
return emitOpError()
<< "expected output tensor dim " << i << " to match "
<< "input dim " << constantPerms[i] << " with value of "
- << inputType.getDimSize(constantPerms[i]);
+ << inputShape.getDimSize(constantPerms[i]);
}
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 55a9fcb15bbc7..3310919d406a2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -368,79 +368,6 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor
// -----
-func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
- // expected-error at +1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
- %0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
- return %0 : tensor<3x13x21x1xf32>
-}
-
-// -----
-
-func.func @test_transpose_rank0_perms() {
- %14 = tensor.empty() : tensor<5x27xi64>
- // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
- %72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
- return
-}
-
-// -----
-
-func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
- // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
- %0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
- return %0 : tensor<3x13x21xf32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
- return %1 : tensor<*xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
- return %1 : tensor<*xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
- // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
- %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
- return %1 : tensor<3x4xi32>
-}
-
-// -----
-
-func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
- // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
- %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
- return %1 : tensor<3x4xi32>
-}
-
-// -----
-
-func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
- // expected-error at +1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
- %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
-}
-
-// -----
-
func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
@@ -783,37 +710,6 @@ func.func @test_tile_io_rank_mismatch() {
return
}
-// -----
-
-// CHECK-LABEL: @test_invalid_constant_permutation
-func.func @test_invalid_constant_permutation() {
- %0 = tensor.empty() : tensor<3x4x5xi32>
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: test_rank_size_constant_permutation
-func.func @test_rank_size_constant_permutation() {
- %0 = arith.constant 6 : index
- %2 = tensor.empty(%0) : tensor<?x27xi64>
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
- return
-}
-
-// -----
-
-// CHECK-LABEL: test_large_constant_permutation
-func.func @test_large_constant_permutation() {
- %0 = arith.constant 6 : index
- %2 = tensor.empty(%0) : tensor<?x27xi64>
- // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
- %3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
- return
-}
// -----
@@ -2061,14 +1957,6 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
// -----
-func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
- // expected-error at +1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
- %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
- return %1 : tensor<f32>
-}
-
-// -----
-
// CHECK-LABEL: test_add_i1
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
// expected-error at +1 {{'tosa.add' op illegal: operand/result data types not supported}}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
new file mode 100644
index 0000000000000..c49cbecd25c78
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -0,0 +1,126 @@
+//--------------------------------------------------------------------------------------------------
+// Test expected errors generated by verifier checks.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
+ // expected-error at +1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
+ %0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
+ return %0 : tensor<3x13x21x1xf32>
+}
+
+// -----
+
+func.func @test_transpose_rank0_perms() {
+ %14 = tensor.empty() : tensor<5x27xi64>
+ // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
+ %72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
+ return
+}
+
+// -----
+
+func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+ // expected-error at +1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
+ %0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
+ return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_num_elements(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
+ // expected-error at +1 {{'tosa.transpose' op expected input1 and output to have same numbers of elements, got 6 and 12}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
+ return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
+ // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
+ // expected-error at +1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
+ return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+ // expected-error at +1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_invalid_constant_permutation
+func.func @test_invalid_constant_permutation() {
+ %0 = tensor.empty() : tensor<3x4x5xi32>
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_rank_size_constant_permutation
+func.func @test_rank_size_constant_permutation() {
+ %0 = arith.constant 6 : index
+ %2 = tensor.empty(%0) : tensor<?x27xi64>
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: test_large_constant_permutation
+func.func @test_large_constant_permutation() {
+ %0 = arith.constant 6 : index
+ %2 = tensor.empty(%0) : tensor<?x27xi64>
+ // expected-error at +1 {{'tosa.transpose' op expected valid permutation indices}}
+ %3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
+ return
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/135219
More information about the Mlir-commits
mailing list