[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (PR #70972)

Felix Schneider llvmlistbot at llvm.org
Thu Nov 2 00:25:30 PDT 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/70972

>From a54a1036260cd9be24fc619a0924922459e8cdf0 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 1 Nov 2023 19:33:45 +0100
Subject: [PATCH 1/3] [mlir][tosa] Add verifier for tosa.tile, fix shape
 inference crash

This patch adds a verifier to `tosa.tile` that checks the input/output
ranks and the length of the `multiples` array. The patch also fixes
a crash in shape inference when an invalid `multiples` array
is supplied.

Fix https://github.com/llvm/llvm-project/issues/70415
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  1 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 26 +++++++++++++++++---
 mlir/test/Dialect/Tosa/canonicalize.mlir     | 10 ++++++++
 mlir/test/Dialect/Tosa/invalid.mlir          |  9 +++++++
 4 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..0a2f3271c37d212 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4ec6714a7e02a8b..375a7bbe38e8ec6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
-  }
+  } else if (inputShape.getRank() != multiples.size())
+    return failure();
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
@@ -878,6 +879,24 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::TileOp::verify() {
+  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+  ShapedType outputType = llvm::cast<ShapedType>(getType());
+  auto multiples = getMultiples();
+
+  if (inputType.hasRank()) {
+    if (inputType.getRank() != multiples.size())
+      return emitOpError("expect 'multiples' array to have length ")
+             << inputType.getRank() << " but got " << multiples.size() << ".";
+    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
+      return emitOpError("expect same input and output tensor rank.");
+  } else if (outputType.hasRank() && outputType.getRank() != multiples.size())
+    return emitOpError("expect 'multiples' array to have length ")
+           << outputType.getRank() << " but got " << multiples.size() << ".";
+
+  return success();
+}
+
 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;
@@ -1830,9 +1849,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 102c9ed1578cde9..ecb937c337a5830 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
   %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @fold_tile_rank_zero
+func.func nested @fold_tile_rank_zero() {
+  // CHECK-NOT: tosa.tile
+  %0 = tensor.empty() : tensor<i32>
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8e23a1fde04bc82..4a517cdec1fd7bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
   %1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+func.func @test_tile_invalid_multiples() {
+  %0 = tensor.empty() : tensor<4x31x31xf32>
+  // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  return
+}

>From fd7b1ebe8fe120a20590ca2e5c661a4d095abddc Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 2 Nov 2023 08:15:00 +0100
Subject: [PATCH 2/3] Return the tile result in the rank-zero fold test

---
 mlir/test/Dialect/Tosa/canonicalize.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index ecb937c337a5830..fd51d287bca0580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -607,9 +607,9 @@ func.func nested @fold_reduce_rank_zero() {
 // -----
 
 // CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() {
+func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   // CHECK-NOT: tosa.tile
   %0 = tensor.empty() : tensor<i32>
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
-  return
+  return %1 : tensor<i32>
 }

>From 9e2fd77be5eede329a83a2824271f993678a4497 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 2 Nov 2023 08:25:16 +0100
Subject: [PATCH 3/3] Revert unrelated clang-format changes in WhileOp::parse

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 375a7bbe38e8ec6..259799725622269 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1849,8 +1849,9 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands " << "(expected "
-           << operands.size() << " got " << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands "
+           << "(expected " << operands.size() << " got "
+           << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.



More information about the Mlir-commits mailing list