[Mlir-commits] [mlir] 3651f37 - [mlir][tosa] Check for unranked tensors during validation (#68509)
llvmlistbot at llvm.org
Mon Oct 23 01:45:07 PDT 2023
Author: Sarthak Gupta
Date: 2023-10-23T09:45:03+01:00
New Revision: 3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e
URL: https://github.com/llvm/llvm-project/commit/3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e
DIFF: https://github.com/llvm/llvm-project/commit/3651f377f632d9152f2bd8fc2f9302ec9f9bdd5e.diff
LOG: [mlir][tosa] Check for unranked tensors during validation (#68509)
Fixes https://github.com/llvm/llvm-project/issues/67760
`levelCheckRank` now ensures that the tensors used by TOSA operations are
ranked.
During TOSA validation, `levelCheckRank` called `getRank()` on a tensor
type without first checking whether the tensor is ranked, which triggers
an assertion failure for unranked tensors. I see two ways to fix this:
- Check `type.getRank() > tosaLevel.MAX_RANK` only if the tensor is
ranked, and otherwise proceed as usual
(e.g. `if (type.hasRank() && type.getRank() > tosaLevel.MAX_RANK)`), OR
- Emit a level-check error for unranked tensors (the approach taken here).
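The two options above can be sketched outside MLIR as follows. This is a minimal, non-authoritative model: `ShapedTypeStub`, `kMaxRank`, and both function names are hypothetical stand-ins for `mlir::ShapedType`, `tosaLevel.MAX_RANK`, and `levelCheckRank`, and diagnostics go to `std::cerr` instead of `op->emitOpError()`. The stub's `getRank()` asserts on unranked types to mimic the crash described above.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for mlir::ShapedType. std::nullopt models an
// unranked tensor such as tensor<*xf32>.
struct ShapedTypeStub {
  std::optional<int64_t> rank;
  bool hasRank() const { return rank.has_value(); }
  // Like the real API, only valid on ranked types; asserts otherwise.
  int64_t getRank() const {
    assert(rank.has_value() && "getRank() called on unranked type");
    return *rank;
  }
};

// Hypothetical rank limit, used only for illustration.
constexpr int64_t kMaxRank = 6;

// Option 1: unranked tensors skip the rank limit and pass silently.
bool levelCheckRankSkipUnranked(const ShapedTypeStub &type,
                                const std::string &checkDesc) {
  if (type.hasRank() && type.getRank() > kMaxRank) {
    std::cerr << "failed level check: " << checkDesc << "\n";
    return false;
  }
  return true;
}

// Option 2 (the approach this commit takes): unranked tensors fail.
bool levelCheckRankRejectUnranked(const ShapedTypeStub &type,
                                  const std::string &checkDesc) {
  if (!type.hasRank()) {
    std::cerr << "failed level check: unranked tensor\n";
    return false;
  }
  if (type.getRank() > kMaxRank) {
    std::cerr << "failed level check: " << checkDesc << "\n";
    return false;
  }
  return true;
}
```

With `ShapedTypeStub{std::nullopt}`, option 1 returns `true` and option 2 returns `false`. A rank-7 type fails under both. Option 2 is stricter: an operation such as `tosa.slice` on `tensor<*xf32>` is reported as a level-check failure rather than silently accepted, which is what the new test below expects.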
Added:
Modified:
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/test/Dialect/Tosa/level_check.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 8a2254fc24effe2..424a31175d61707 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -156,6 +156,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool levelCheckRank(Operation *op, const Value &v,
const std::string &checkDesc) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+ if (!type.hasRank()) {
+ op->emitOpError() << "failed level check: unranked tensor";
+ return false;
+ }
if (type.getRank() > tosaLevel.MAX_RANK) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 68238087f5c2523..443baf8a5f6e4f8 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -695,4 +695,13 @@ func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x
return %0 : tensor<1x1x1x1x1x1x10xi32>
}
+// -----
+
+// CHECK-LABEL: unranked_tensor
+func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
+ // expected-error@+1 {{'tosa.slice' op failed level check: unranked tensor}}
+ %0 = "tosa.slice"(%arg0) {start = array<i64>, size = array<i64>} :
+ (tensor<*xf32>) -> tensor<*xf32>
+ return
+}