[Mlir-commits] [mlir] 3651f37 - [mlir][tosa] Check for unranked tensors during validation (#68509)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 23 01:45:07 PDT 2023


Author: Sarthak Gupta
Date: 2023-10-23T09:45:03+01:00
New Revision: 3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e

URL: https://github.com/llvm/llvm-project/commit/3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e
DIFF: https://github.com/llvm/llvm-project/commit/3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e.diff

LOG: [mlir][tosa] Check for unranked tensors during validation (#68509)

Fixes https://github.com/llvm/llvm-project/issues/67760
`levelCheckRank` now ensures that the tensors of TOSA operations are ranked.

During TOSA validation in `levelCheckRank`, we were querying the rank of a
tensor without first checking whether it is ranked, which triggers an
assertion failure for unranked tensors. I see two ways to fix this:

- Only check `type.getRank() > tosaLevel.MAX_RANK` if the tensor is
ranked, and otherwise proceed as usual
(e.g. `if (type.hasRank() && type.getRank() > tosaLevel.MAX_RANK)`), OR
- Emit an error for unranked tensors.

This commit takes the second approach.
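The two options can be contrasted with a minimal, self-contained sketch. It does not use MLIR: `MockShapedType`, `kMaxRank` (assumed to be 6 here purely for illustration), and both function names are hypothetical stand-ins, with `getRank()` modeled on the assertion that the original code tripped.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for mlir::ShapedType; std::nullopt models tensor<*xf32>.
struct MockShapedType {
  std::optional<int64_t> rank;

  bool hasRank() const { return rank.has_value(); }

  int64_t getRank() const {
    // Querying the rank of an unranked type is a programming error; this is
    // the kind of assertion the pre-fix levelCheckRank could hit.
    assert(hasRank() && "cannot query rank of unranked shaped type");
    return *rank;
  }
};

// Assumed limit for illustration only; the pass uses tosaLevel.MAX_RANK.
constexpr int64_t kMaxRank = 6;

// Option 1 (hypothetical name): skip the check, so unranked tensors pass.
bool levelCheckRankSkipUnranked(const MockShapedType &type,
                                const std::string &checkDesc,
                                std::string &error) {
  if (type.hasRank() && type.getRank() > kMaxRank) {
    error = "failed level check: " + checkDesc;
    return false;
  }
  return true;
}

// Option 2 (hypothetical name): reject unranked tensors before reading the
// rank, mirroring the change in the diff.
bool levelCheckRankRejectUnranked(const MockShapedType &type,
                                  const std::string &checkDesc,
                                  std::string &error) {
  if (!type.hasRank()) {
    error = "failed level check: unranked tensor";
    return false;
  }
  if (type.getRank() > kMaxRank) {
    error = "failed level check: " + checkDesc;
    return false;
  }
  return true;
}
```

Both variants avoid the assertion. The difference is that option 1 silently accepts a tensor whose rank cannot be bounded, while option 2 reports it, which is what the new `test_unranked_tensor` case expects.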

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/level_check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 8a2254fc24effe2..424a31175d61707 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -156,6 +156,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool levelCheckRank(Operation *op, const Value &v,
                       const std::string &checkDesc) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+      if (!type.hasRank()) {
+        op->emitOpError() << "failed level check: unranked tensor";
+        return false;
+      }
       if (type.getRank() > tosaLevel.MAX_RANK) {
         op->emitOpError() << "failed level check: " << checkDesc;
         return false;

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 68238087f5c2523..443baf8a5f6e4f8 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
   return %0 : tensor<1x1x1x1x1x1x10xi32>
 }
 
+// -----
+
+// CHECK-LABEL: unranked_tensor
+func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
+  // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
+  %0 = "tosa.slice"(%arg0) {start = array<i64>, size = array<i64>} :
+          (tensor<*xf32>) -> tensor<*xf32>
+  return
+}
 