[Mlir-commits] [mlir] [mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (PR #70972)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 00:15:57 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

This patch adds a verifier to `tosa.tile` that checks the input/output ranks and the length of the `multiples` array. It also fixes a crash in shape inference when an invalid `multiples` array is supplied.

Fix https://github.com/llvm/llvm-project/issues/70415

---
Full diff: https://github.com/llvm/llvm-project/pull/70972.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+22-4) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+10) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..0a2f3271c37d212 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4ec6714a7e02a8b..375a7bbe38e8ec6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -863,7 +863,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
+  // A ranked input needs exactly one multiple per dimension. Reject a
+  // mismatched `multiples` array here rather than crash below (issue #70415).
+  if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
+    return failure();
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
@@ -878,6 +882,31 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   return success();
 }
 
+/// Verifies that the `multiples` array has one entry per tensor dimension and
+/// that the input and output tensors have the same rank. Each check only runs
+/// when the tensor types it depends on are ranked.
+LogicalResult tosa::TileOp::verify() {
+  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+  ShapedType outputType = llvm::cast<ShapedType>(getType());
+  // Compare ranks as int64_t to avoid a signed/unsigned (size_t) comparison.
+  const int64_t numMultiples = static_cast<int64_t>(getMultiples().size());
+
+  if (inputType.hasRank()) {
+    if (inputType.getRank() != numMultiples)
+      return emitOpError("expect 'multiples' array to have length ")
+             << inputType.getRank() << " but got " << numMultiples << ".";
+    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
+      return emitOpError("expect same input and output tensor rank.");
+  } else if (outputType.hasRank() && outputType.getRank() != numMultiples) {
+    // With an unranked input, a ranked result still fixes the expected
+    // length of `multiples`.
+    return emitOpError("expect 'multiples' array to have length ")
+           << outputType.getRank() << " but got " << numMultiples << ".";
+  }
+
+  return success();
+}
+
 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 102c9ed1578cde9..fd51d287bca0580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
   %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @fold_tile_rank_zero
+func.func nested @fold_tile_rank_zero() -> tensor<i32> {
+  // CHECK-NOT: tosa.tile
+  %0 = tensor.empty() : tensor<i32>
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 8e23a1fde04bc82..4a517cdec1fd7bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -329,3 +329,28 @@ func.func @test_slice_invalid_size() {
   %1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+func.func @test_tile_invalid_multiples() {
+  %0 = tensor.empty() : tensor<4x31x31xf32>
+  // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
+  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  return
+}
+
+// -----
+
+func.func @test_tile_invalid_rank(%arg0: tensor<4x31xf32>) {
+  // expected-error@+1 {{'tosa.tile' op expect same input and output tensor rank.}}
+  %0 = tosa.tile %arg0 {multiples = array<i64: 1, 1>} : (tensor<4x31xf32>) -> tensor<4x31x1xf32>
+  return
+}
+
+// -----
+
+func.func @test_tile_unranked_input_invalid_multiples(%arg0: tensor<*xf32>) {
+  // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 2.}}
+  %0 = tosa.tile %arg0 {multiples = array<i64: 1, 1>} : (tensor<*xf32>) -> tensor<4x31x1xf32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/70972


More information about the Mlir-commits mailing list