[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `tosa.table` (PR #103708)
Longsheng Mou
llvmlistbot at llvm.org
Wed Aug 14 00:26:32 PDT 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/103708
This patch adds a verifier to `tosa.table`, which fixes a crash. Fixes #103086.
>From e42361ad83eded1c5168ad7679353a935904fcf4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Wed, 14 Aug 2024 15:22:22 +0800
Subject: [PATCH] [mlir][tosa] Add verifier for `tosa.table`
This patch adds a verifier to `tosa.table`, which fixes a crash.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 23 +++++++++++++++++
mlir/test/Dialect/Tosa/invalid.mlir | 27 ++++++++++++++++++++
3 files changed, 52 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7871b46724a03d..0be0f8ef2d7a0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -897,6 +897,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
let assemblyFormat = [{
$input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39ea7a5b61f5ec..e42a5678ebc73a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -864,6 +864,29 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
return success();
}
+/// Verifies that the result of `tosa.table` has the same shape as its input.
+///
+/// Shapes are only compared when both the input and the result are ranked:
+/// `TensorType::getShape()` must not be called on an unranked tensor type.
+/// A dimension that is dynamic on either side is treated as compatible with
+/// any size, since it can only be resolved at runtime.
+LogicalResult tosa::TableOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+
+  // An unranked input or result carries no shape to compare against; return
+  // before calling getShape(), which requires a ranked type.
+  if (!inputType.hasRank() || !outputType.hasRank())
+    return success();
+
+  if (inputType.getRank() != outputType.getRank())
+    return emitOpError()
+           << "expected input tensor rank to equal result tensor rank";
+
+  auto inputDims = inputType.getShape();
+  auto outputDims = outputType.getShape();
+  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
+    int64_t dim = it.index();
+    auto [inputDim, outputDim] = it.value();
+    // Only two static sizes can be proven to disagree; this also avoids
+    // printing the kDynamic sentinel in the diagnostic below.
+    if (ShapedType::isDynamic(inputDim) || ShapedType::isDynamic(outputDim))
+      continue;
+    if (outputDim != inputDim) {
+      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
+                           << " doesn't match dim(input, " << dim
+                           << ") = " << inputDim;
+    }
+  }
+  return success();
+}
+
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e1fcf056480083..e723aef3815ce6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -448,3 +448,30 @@ func.func @test_large_constant_permutation() {
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}
+
+// -----
+
+// CHECK-LABEL: test_table_rank0_table
+// The table operand must be a 1-D tensor; a rank-0 table is rejected by the
+// operand type constraint.
+func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
+  // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_rank_mismatch
+// The result must have the same rank as the input. A well-formed 513-entry
+// i16 table is used so that only the rank mismatch is diagnosed.
+func.func @test_table_io_rank_mismatch(%arg0: tensor<64xi16>, %arg1: tensor<513xi16>) {
+  // expected-error@+1 {{'tosa.table' op expected input tensor rank to equal result tensor rank}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<513xi16>) -> tensor<64x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_shape_mismatch
+// Static result dimensions must match the corresponding input dimensions; the
+// dynamic leading dimension is not checked. A well-formed 513-entry i16 table
+// is used so that only the shape mismatch is diagnosed.
+func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<513xi16>) {
+  // expected-error@+1 {{'tosa.table' op dim(result, 1) = 15 doesn't match dim(input, 1) = 16}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<513xi16>) -> tensor<?x15xi16>
+  return
+}
More information about the Mlir-commits
mailing list