[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (PR #70972)

Felix Schneider llvmlistbot at llvm.org
Wed Nov 1 11:38:38 PDT 2023


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/70972

This patch adds a verifier to `tosa.tile` which checks input/output ranks and the length of the `multiples` array. The patch also fixes a crash in the shape inference when an invalid `multiples` array is supplied.

Fix https://github.com/llvm/llvm-project/issues/70415

>From a54a1036260cd9be24fc619a0924922459e8cdf0 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 1 Nov 2023 19:33:45 +0100
Subject: [PATCH] [mlir][tosa] Add verifier for tosa.tile, fix shape inference
 crash

This patch adds a verifier to `tosa.tile` which checks input/output
ranks and the length of the `multiples` array. The patch also fixes
a crash in the shape inference when an invalid `multiples` array
is supplied.

Fix https://github.com/llvm/llvm-project/issues/70415
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  1 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 26 +++++++++++++++++---
 mlir/test/Dialect/Tosa/canonicalize.mlir     | 10 ++++++++
 mlir/test/Dialect/Tosa/invalid.mlir          |  9 +++++++
 4 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..0a2f3271c37d212 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4ec6714a7e02a8b..375a7bbe38e8ec6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
-  }
+  } else if (inputShape.getRank() != multiples.size())
+    return failure();
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
@@ -878,6 +879,24 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::TileOp::verify() {
+  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+  ShapedType outputType = llvm::cast<ShapedType>(getType());
+  auto multiples = getMultiples();
+
+  if (inputType.hasRank()) {
+    if (inputType.getRank() != multiples.size())
+      return emitOpError("expect 'multiples' array to have length ")
+             << inputType.getRank() << " but got " << multiples.size() << ".";
+    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
+      return emitOpError("expect same input and output tensor rank.");
+  } else if (outputType.hasRank() && outputType.getRank() != multiples.size())
+    return emitOpError("expect 'multiples' array to have length ")
+           << outputType.getRank() << " but got " << multiples.size() << ".";
+
+  return success();
+}
+
 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;
@@ -1830,9 +1849,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (functionType.getNumInputs() != operands.size()) {
     return parser.emitError(typeLoc)
-           << "expected as many input types as operands "
-           << "(expected " << operands.size() << " got "
-           << functionType.getNumInputs() << ")";
+           << "expected as many input types as operands " << "(expected "
+           << operands.size() << " got " << functionType.getNumInputs() << ")";
   }
 
   // Resolve input operands.
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 102c9ed1578cde9..ecb937c337a5830 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
   %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @fold_tile_rank_zero
+func.func nested @fold_tile_rank_zero() {
+  // CHECK-NOT: tosa.tile
+  %0 = tensor.empty() : tensor<i32>
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  return
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8e23a1fde04bc82..4a517cdec1fd7bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
   %1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+func.func @test_tile_invalid_multiples() {
+  %0 = tensor.empty() : tensor<4x31x31xf32>
+  // expected-error @+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  return
+}



More information about the Mlir-commits mailing list