[Mlir-commits] [mlir] [mlir][tosa] Check for unranked tensors during validation (PR #68509)

Sarthak Gupta llvmlistbot at llvm.org
Sun Oct 8 00:51:30 PDT 2023


https://github.com/gptsarthak created https://github.com/llvm/llvm-project/pull/68509

Fixes https://github.com/llvm/llvm-project/issues/67760
`levelCheckRank` now ensures that the tensors for TOSA operations are not unranked.

During TOSA validation, `levelCheckRank` called `getRank()` on a tensor type without first checking whether the type is ranked, which triggers an assertion failure for unranked tensors. I see two ways to fix this:

- Only check `type.getRank() > tosa_level.MAX_RANK` if the tensor is ranked, and then proceed as usual (for example, `if (type.hasRank() && type.getRank() > tosa_level.MAX_RANK)`), OR
- Emit an error for unranked tensors and fail the level check.

I have implemented the second method in this PR. Are there any better ways to fix this issue?
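
To make the trade-off concrete, here is a minimal standalone sketch of both options. It does not use MLIR: `MockShapedType`, `kMaxRank`, and both function names are hypothetical stand-ins for `ShapedType`, `tosa_level.MAX_RANK`, and `levelCheckRank`, and the diagnostic is written to a string instead of going through `emitOpError`. The value 6 for `kMaxRank` is an assumption chosen for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for mlir::ShapedType. As in the bug report,
// getRank() asserts when the type is unranked.
struct MockShapedType {
  std::optional<int64_t> rank; // std::nullopt models tensor<*xf32>

  bool hasRank() const { return rank.has_value(); }

  int64_t getRank() const {
    assert(hasRank() && "cannot query rank of unranked shaped type");
    return *rank;
  }
};

// Assumed limit standing in for tosa_level.MAX_RANK.
constexpr int64_t kMaxRank = 6;

// Option 1: only compare the rank when the type is ranked.
// Unranked types pass the level check silently.
bool levelCheckRankSkipUnranked(const MockShapedType &type, std::string &diag) {
  if (type.hasRank() && type.getRank() > kMaxRank) {
    diag = "failed level check: rank too large";
    return false;
  }
  return true;
}

// Option 2 (the approach taken in this PR): reject unranked types
// before querying the rank.
bool levelCheckRankRejectUnranked(const MockShapedType &type, std::string &diag) {
  if (!type.hasRank()) {
    diag = "failed level check: unranked tensor";
    return false;
  }
  if (type.getRank() > kMaxRank) {
    diag = "failed level check: rank too large";
    return false;
  }
  return true;
}
```

Both variants avoid calling `getRank()` on an unranked type. They differ only in whether an unranked tensor is accepted by the level check (option 1) or reported as an error (option 2).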

This is my first patch in mlir!



From 9d9d0bf74d1ce75fb4e122501ac7a9c2af621706 Mon Sep 17 00:00:00 2001
From: gptsarthak <sarthakgpt95 at gmail.com>
Date: Sun, 8 Oct 2023 13:05:22 +0530
Subject: [PATCH] [mlir][tosa] Check for unranked tensors during validation

levelCheckRank ensures that the tensors for tosa operations are not unranked
Fixes https://github.com/llvm/llvm-project/issues/67760
---
 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 4 ++++
 mlir/test/Dialect/Tosa/level_check.mlir             | 9 +++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..078200844c35b33 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -150,6 +150,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool levelCheckRank(Operation *op, const Value &v,
                       const std::string &check_desc) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+      if (!type.hasRank()) {
+        op->emitOpError() << "failed level check: unranked tensor";
+        return false;
+      }
       if (type.getRank() > tosa_level.MAX_RANK) {
         op->emitOpError() << "failed level check: " << check_desc;
         return false;
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..376cf3baa82a99c 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
   return %0 : tensor<1x1x1x1x1x1x10xi32>
 }
 
+// -----
+// CHECK-LABEL: unranked_tensor
+func.func @test_unranked_tensor(%arg0: tensor<1x1x1x1x13x21xf32>) {
+  // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
+  %0 = "tosa.slice"(%arg0) {start = array<i64: 0, 0, 0, 0, 6, 8, 0>, size = array<i64: 1, 1, 1, 1, 4, 11, 1>} :
+          (tensor<1x1x1x1x13x21xf32>) -> tensor<*xf32>
+  return
+}
+
 


