[Mlir-commits] [mlir] b6d67af - [mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (#70972)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 11:39:08 PDT 2023


Author: Felix Schneider
Date: 2023-11-02T19:39:04+01:00
New Revision: b6d67af2098fa92a557b72b7508fdd4f5e3488eb

URL: https://github.com/llvm/llvm-project/commit/b6d67af2098fa92a557b72b7508fdd4f5e3488eb
DIFF: https://github.com/llvm/llvm-project/commit/b6d67af2098fa92a557b72b7508fdd4f5e3488eb.diff

LOG: [mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (#70972)

This patch adds a verifier to `tosa.tile` which checks input/output
ranks and the length of the `multiples` array. The patch also fixes a
crash in shape inference when an invalid `multiples` array is
supplied.

Fix https://github.com/llvm/llvm-project/issues/70415

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..0a2f3271c37d212 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4ec6714a7e02a8b..259799725622269 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
-  }
+  } else if (inputShape.getRank() != multiples.size())
+    return failure();
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
@@ -878,6 +879,31 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
+/// Verifies tosa.tile's rank constraints. With a ranked input, the 'multiples'
+/// array must have one entry per input dimension and a ranked output must have
+/// the same rank as the input. With an unranked input, a ranked output's rank
+/// must equal the length of 'multiples'.
+LogicalResult tosa::TileOp::verify() {
+  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+  ShapedType outputType = llvm::cast<ShapedType>(getType());
+  // getRank() yields a signed int64_t while size() yields an unsigned size_t;
+  // convert once so the rank comparisons below do not mix signedness.
+  int64_t multiplesLength = static_cast<int64_t>(getMultiples().size());
+
+  if (inputType.hasRank()) {
+    if (inputType.getRank() != multiplesLength)
+      return emitOpError("expect 'multiples' array to have length ")
+             << inputType.getRank() << " but got " << multiplesLength << ".";
+    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
+      return emitOpError("expect same input and output tensor rank.");
+  } else if (outputType.hasRank() && outputType.getRank() != multiplesLength) {
+    return emitOpError("expect 'multiples' array to have length ")
+           << outputType.getRank() << " but got " << multiplesLength << ".";
+  }
+
+  return success();
+}
+
 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 102c9ed1578cde9..fd51d287bca0580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
   %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @fold_tile_rank_zero
+func.func nested @fold_tile_rank_zero() -> tensor<i32> {
+  // CHECK-NOT: tosa.tile
+  %empty = tensor.empty() : tensor<i32>
+  %tiled = tosa.tile %empty {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  return %tiled : tensor<i32>
+}

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8e23a1fde04bc82..4a517cdec1fd7bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
   %1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+func.func @test_tile_invalid_multiples() {
+  %0 = tensor.empty() : tensor<4x31x31xf32>
+  // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  return
+}


        


More information about the Mlir-commits mailing list