[Mlir-commits] [mlir] 1e34706 - [mlir][tosa] Add verifier for `tosa.table` (#103708)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 14 23:30:03 PDT 2024


Author: Longsheng Mou
Date: 2024-08-15T14:30:00+08:00
New Revision: 1e34706232e5f2865ff918ba8e9f840f38cdef07

URL: https://github.com/llvm/llvm-project/commit/1e34706232e5f2865ff918ba8e9f840f38cdef07
DIFF: https://github.com/llvm/llvm-project/commit/1e34706232e5f2865ff918ba8e9f840f38cdef07.diff

LOG: [mlir][tosa] Add verifier for `tosa.table` (#103708)

This patch adds a verifier to `tosa.table`, which fixes a crash. Fixes
#103086.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7871b46724a03d..0be0f8ef2d7a0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -897,6 +897,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
   let assemblyFormat = [{
     $input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39ea7a5b61f5ec..d4e49b6e3c044c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -864,6 +864,29 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
   return success();
 }
 
+/// Checks that the result's rank and static dims match the input's.
+LogicalResult tosa::TableOp::verify() {
+  TensorType inputType = getInput().getType();
+  TensorType outputType = getOutput().getType();
+
+  // getShape() is only valid on ranked tensors; nothing to compare otherwise.
+  if (!inputType.hasRank() || !outputType.hasRank())
+    return success();
+  if (inputType.getRank() != outputType.getRank())
+    return emitOpError()
+           << "expected input tensor rank to equal result tensor rank";
+
+  // Every static result dim must equal the corresponding input dim.
+  for (auto [dim, inputDim, outputDim] :
+       llvm::enumerate(inputType.getShape(), outputType.getShape())) {
+    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim)
+      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
+                           << " doesn't match dim(input, " << dim
+                           << ") = " << inputDim;
+  }
+  return success();
+}
+
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e1fcf056480083..e723aef3815ce6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -448,3 +448,30 @@ func.func @test_large_constant_permutation() {
   %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_table_rank0_table
+func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
+  // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_rank_mismatch
+func.func @test_table_io_rank_mismatch(%arg0: tensor<64xi16>, %arg1: tensor<6xi16>) {
+  // expected-error@+1 {{'tosa.table' op expected input tensor rank to equal result tensor rank}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<6xi16>) -> tensor<64x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_table_io_shape_mismatch
+func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6xi16>) {
+  // expected-error@+1 {{'tosa.table' op dim(result, 1) = 15 doesn't match dim(input, 1) = 16}}
+  %0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
+  return
+}


        


More information about the Mlir-commits mailing list