[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `tosa.table` (PR #103708)

Longsheng Mou llvmlistbot at llvm.org
Wed Aug 14 00:26:32 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/103708

This patch adds a verifier to `tosa.table` which fixes a crash. Fix #103086.

From e42361ad83eded1c5168ad7679353a935904fcf4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Wed, 14 Aug 2024 15:22:22 +0800
Subject: [PATCH] [mlir][tosa] Add verifier for `tosa.table`

This patch adds a verifier to `tosa.table` which fixes a crash.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  2 ++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 23 +++++++++++++++++
 mlir/test/Dialect/Tosa/invalid.mlir          | 27 ++++++++++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7871b46724a03d..0be0f8ef2d7a0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -897,6 +897,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
   let assemblyFormat = [{
     $input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39ea7a5b61f5ec..e42a5678ebc73a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -864,6 +864,29 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::TableOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+
+  if (inputType.hasRank() && outputType.hasRank() &&
+      inputType.getRank() != outputType.getRank())
+    return emitOpError()
+           << "expected input tensor rank to equal result tensor rank";
+
+  auto inputDims = inputType.getShape();
+  auto outputDims = outputType.getShape();
+  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
+    int64_t dim = it.index();
+    auto [inputDim, outputDim] = it.value();
+    if (outputDim != ShapedType::kDynamic && outputDim != inputDim) {
+      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
+                           << " doesn't match dim(input, " << dim
+                           << ") = " << inputDim;
+    }
+  }
+  return success();
+}
+
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e1fcf056480083..e723aef3815ce6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -448,3 +448,30 @@ func.func @test_large_constant_permutation() {
   %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_table_rank0_table
+func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
+  // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_rank_mismatch
+func.func @test_table_io_rank_mismatch(%arg0: tensor<64xi16>, %arg1: tensor<6xi16>) {
+  // expected-error@+1 {{'tosa.table' op expected input tensor rank to equal result tensor rank}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<6xi16>) -> tensor<64x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_shape_mismatch
+func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6xi16>) {
+  // expected-error@+1 {{'tosa.table' op dim(result, 1) = 15 doesn't match dim(input, 1) = 16}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
+  return
+}



More information about the Mlir-commits mailing list