[Mlir-commits] [mlir] [mlir][tosa] Check for unranked tensors during validation (PR #68509)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 10 14:03:41 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-tosa

Author: Sarthak Gupta (gptsarthak)

<details>
<summary>Changes</summary>

Fixes https://github.com/llvm/llvm-project/issues/67760
With this patch, `levelCheckRank` reports an error, instead of asserting, when a tosa operation uses an unranked tensor.

During tosa validation, `levelCheckRank` called `getRank()` on a tensor type without first checking whether the tensor is ranked. For an unranked tensor this triggers an assertion failure. I see two ways to fix this:

- Only check `type.getRank() > tosa_level.MAX_RANK` if the tensor is ranked, and otherwise proceed as usual (e.g. `if (type.hasRank() && type.getRank() > tosa_level.MAX_RANK)`), or
- Emit an error when an unranked tensor is encountered.
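The behavioral difference between the two options can be sketched outside MLIR. The snippet below is a standalone illustration, not MLIR code: `MockShapedType` is a hypothetical stand-in for `mlir::ShapedType` that only models `hasRank()` and an asserting `getRank()`, and `kMaxRank = 6` is an assumed value standing in for `tosa_level.MAX_RANK`. The function names are hypothetical too.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for mlir::ShapedType. std::nullopt models an
// unranked type such as tensor<*xf32>. Like the real getRank(), querying
// the rank of an unranked type trips an assertion.
struct MockShapedType {
  std::optional<int64_t> rank;

  bool hasRank() const { return rank.has_value(); }
  int64_t getRank() const {
    assert(hasRank() && "cannot query rank of an unranked type");
    return *rank;
  }
};

// Assumed level limit, standing in for tosa_level.MAX_RANK.
constexpr int64_t kMaxRank = 6;

// Option 1: guard the rank query; unranked types pass the level check.
bool levelCheckRankSkipUnranked(const MockShapedType &type) {
  if (type.hasRank() && type.getRank() > kMaxRank) {
    std::cerr << "failed level check: rank\n";
    return false;
  }
  return true;
}

// Option 2 (what this patch does): reject unranked types before
// querying the rank.
bool levelCheckRankRejectUnranked(const MockShapedType &type) {
  if (!type.hasRank()) {
    std::cerr << "failed level check: unranked tensor\n";
    return false;
  }
  if (type.getRank() > kMaxRank) {
    std::cerr << "failed level check: rank\n";
    return false;
  }
  return true;
}
```

Both variants avoid the assertion. They differ only in whether an unranked tensor is accepted (option 1) or treated as a level-check failure (option 2).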

I have implemented the second method in this PR. Are there any better ways to fix this issue?

This is my first patch in mlir!



---
Full diff: https://github.com/llvm/llvm-project/pull/68509.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..078200844c35b33 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -150,6 +150,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool levelCheckRank(Operation *op, const Value &v,
                       const std::string &check_desc) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+      if (!type.hasRank()) {
+        op->emitOpError() << "failed level check: unranked tensor";
+        return false;
+      }
       if (type.getRank() > tosa_level.MAX_RANK) {
         op->emitOpError() << "failed level check: " << check_desc;
         return false;
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..376cf3baa82a99c 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
   return %0 : tensor<1x1x1x1x1x1x10xi32>
 }
 
+// -----
+// CHECK-LABEL: unranked_tensor
+func.func @test_unranked_tensor(%arg0: tensor<1x1x1x1x13x21xf32>) {
+  // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
+  %0 = "tosa.slice"(%arg0) {start = array<i64: 0, 0, 0, 0, 6, 8, 0>, size = array<i64: 1, 1, 1, 1, 4, 11, 1>} :
+          (tensor<1x1x1x1x13x21xf32>) -> tensor<*xf32>
+  return
+}
+
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/68509

