[Mlir-commits] [mlir] [mlir][tosa] Check for unranked tensors during validation (PR #68509)

Sarthak Gupta llvmlistbot at llvm.org
Wed Oct 11 01:29:22 PDT 2023


https://github.com/gptsarthak updated https://github.com/llvm/llvm-project/pull/68509

From cf45b9d2dc8c120b0269992ff6740abbe72f93f3 Mon Sep 17 00:00:00 2001
From: gptsarthak <sarthakgpt95 at gmail.com>
Date: Wed, 11 Oct 2023 13:58:59 +0530
Subject: [PATCH] [mlir][tosa] Check for unranked tensors during validation

levelCheckRank now ensures that tensors used by TOSA operations are ranked, reporting an unranked tensor as a failed level check before its rank is queried.
Fixes https://github.com/llvm/llvm-project/issues/67760
---
 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 4 ++++
 mlir/test/Dialect/Tosa/level_check.mlir             | 9 +++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 52885e69c3924f2..078200844c35b33 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -150,6 +150,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool levelCheckRank(Operation *op, const Value &v,
                       const std::string &check_desc) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+      if (!type.hasRank()) {
+        op->emitOpError() << "failed level check: unranked tensor";
+        return false;
+      }
       if (type.getRank() > tosa_level.MAX_RANK) {
         op->emitOpError() << "failed level check: " << check_desc;
         return false;
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..839ee4055e7aab9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
   return %0 : tensor<1x1x1x1x1x1x10xi32>
 }
 
+// -----
+
+// CHECK-LABEL: unranked_tensor
+func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
+  // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
+  %0 = "tosa.slice"(%arg0) {start = array<i64>, size = array<i64>} :
+          (tensor<*xf32>) -> tensor<*xf32>
+  return
+}
 
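The patch relies on checking order. In MLIR, `ShapedType::getRank()` is only valid on ranked types and asserts otherwise, so `hasRank()` must be tested before the `MAX_RANK` comparison. The following is a minimal standalone C++ sketch of that order, written without MLIR so it compiles on its own. `FakeShapedType`, `levelCheckRankSketch`, and `kMaxRank` are hypothetical stand-ins, not MLIR APIs. The limit of 6 is an assumed value for illustration; the real pass reads `tosa_level.MAX_RANK`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for mlir::ShapedType (not an MLIR API).
// std::nullopt models an unranked tensor such as tensor<*xf32>.
struct FakeShapedType {
  std::optional<int64_t> rank;

  bool hasRank() const { return rank.has_value(); }

  // Like ShapedType::getRank(): only valid on ranked types, asserts otherwise.
  int64_t getRank() const {
    assert(hasRank() && "cannot query rank of unranked shaped type");
    return *rank;
  }
};

// Assumed rank limit for illustration; the real pass uses tosa_level.MAX_RANK.
constexpr int64_t kMaxRank = 6;

// Hypothetical sketch of the patched levelCheckRank: reject unranked types
// before calling getRank(), then compare the rank against the level limit.
// On failure, writes the diagnostic text to `error` and returns false.
bool levelCheckRankSketch(const FakeShapedType &type,
                          const std::string &checkDesc, std::string &error) {
  if (!type.hasRank()) {
    error = "failed level check: unranked tensor";
    return false;
  }
  if (type.getRank() > kMaxRank) {
    error = "failed level check: " + checkDesc;
    return false;
  }
  return true;
}
```

With this sketch, `FakeShapedType{std::nullopt}` fails with the same "unranked tensor" message that the new test expects, a rank-7 type fails the rank-limit comparison under the assumed limit, and a rank-4 type passes. Because the unranked case returns early, `getRank()` is never reached for an unranked type.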


