[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `tosa.transpose` (PR #75376)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 12:24:20 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

This patch adds a verifier to `tosa.transpose` which fixes a crash.

Related: https://github.com/llvm/llvm-project/pull/74367

Fix https://github.com/llvm/llvm-project/issues/74479

---
Full diff: https://github.com/llvm/llvm-project/pull/75376.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+41) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+43) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (-12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a55a8015e09b70..9dde59f634d739 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1678,6 +1678,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 259fb6394669a2..661126f4df9976 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -1054,6 +1055,46 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::TransposeOp::verify() {
+  TensorType inputType = getInput1().getType();
+  TensorType permType = getPerms().getType();
+  TensorType outputType = getOutput().getType();
+
+  if (permType.hasRank() && permType.getRank() != 1)
+    return emitOpError()
+           << "expected permutation tensor to be rank 1 but got rank "
+           << permType.getRank();
+  if (inputType.hasRank() && permType.hasRank())
+    if (!permType.isDynamicDim(0) &&
+        permType.getDimSize(0) != inputType.getRank())
+      return emitOpError() << "expected permutation tensor dim 0 to have size "
+                           << inputType.getRank()
+                           << " (input rank) but got size "
+                           << permType.getDimSize(0);
+  if (inputType.hasRank() && outputType.hasRank() &&
+      inputType.getRank() != outputType.getRank())
+    return emitOpError()
+           << "expected input tensor rank to equal result tensor rank";
+  if (outputType.hasRank() && permType.hasRank())
+    if (!permType.isDynamicDim(0) &&
+        permType.getDimSize(0) != outputType.getRank())
+      return emitOpError() << "expected permutation tensor dim 0 to have size "
+                           << outputType.getRank()
+                           << " (output rank) but got size "
+                           << permType.getDimSize(0);
+
+  SmallVector<int64_t> constantPerms;
+  if (succeeded(getConstantPerms(constantPerms))) {
+    // Assert that the permutation tensor has a rank, which means that the rank
+    // has been verified above.
+    assert(permType.hasRank() &&
+           "Unexpectedly found permutation tensor without rank");
+    if (!isPermutationVector(constantPerms))
+      return emitOpError() << "expected valid permutation tensor";
+  }
+  return success();
+}
+
 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     GatherOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 4a517cdec1fd7b..38ba48f365eabe 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -80,6 +80,49 @@ func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x
 
 // -----
 
+func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
+  // expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x1xf32>
+  return %0 : tensor<3x13x21x1xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_perms_rank(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<3x13x21xf32> {
+  // expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 2}}
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_transpose_rank0_perms() {
+  %14 = tensor.empty() : tensor<5x27xi64>
+  %cst = tensor.empty() : tensor<i32>
+  // expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 0}}
+  %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
+  return
+}
+
+// -----
+
+func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7xi32>) -> tensor<3x13x21xf32> {
+  // expected-error@+1 {{'tosa.transpose' op expected permutation tensor dim 0 to have size 3 (input rank) but got size 7}}
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<7xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
+  %perms = arith.constant dense<[2, 0, 0]> : tensor<3xi32>
+  // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
+  %0 = tosa.transpose %arg0, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+
 func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
   %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index c240f5334c149e..f5cd8bd48de1cc 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1309,15 +1309,3 @@ func.func @test_large_constant_permutation() {
   %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
 }
-
-// -----
-
-// CHECK-LABEL: test_rank0_transpose_perms
-// Fail to infer the shape but not crash.
-func.func @test_rank0_transpose_perms() {
-  %14 = tensor.empty() : tensor<5x27xi64>
-  %cst = tensor.empty() : tensor<i32>
-  // CHECK: tosa.transpose
-  %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
-  return
-}

``````````

</details>


https://github.com/llvm/llvm-project/pull/75376


More information about the Mlir-commits mailing list