[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `tosa.pad` (PR #106351)

Longsheng Mou llvmlistbot at llvm.org
Wed Aug 28 01:46:51 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/106351

This patch adds a verifier to `tosa.pad`, which fixes a crash. The verifier checks that `tosa.pad` has:
- the same input and output tensor rank;
- a 'padding' tensor of rank 2 (when the padding type is ranked).
Fixes #106168.

>From 69d60a0e9ad8b4ba47e6d207e7e98f726d8525a9 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Wed, 28 Aug 2024 16:40:25 +0800
Subject: [PATCH] [mlir][tosa] Add verifier for `tosa.pad`

This patch adds a verifier to `tosa.pad`, which fixes a crash.
The verifier checks that `tosa.pad` has:
- the same input and output tensor rank;
- a 'padding' tensor of rank 2 (when the padding type is ranked).
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  1 +
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 14 +++++++++++
 mlir/test/Dialect/Tosa/invalid.mlir          | 25 ++++++++++++++++++++
 3 files changed, 40 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0be0f8ef2d7a0c..1a132e73be8645 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1594,6 +1594,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c8e2b04eea0e22..6bba63db501e7c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -817,6 +817,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
   return success();
 }
 
+/// Rejects input/output rank mismatches and a ranked non-2-D 'padding'.
+LogicalResult tosa::PadOp::verify() {
+  if (getInput1().getType().getRank() != getOutput().getType().getRank())
+    return emitOpError() << "expect same input and output tensor rank.";
+
+  // An unranked 'padding' type is accepted without a rank check.
+  TensorType padTy = getPadding().getType();
+  const bool badPaddingRank = padTy.hasRank() && padTy.getRank() != 2;
+  if (badPaddingRank)
+    return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+
+  return success();
+}
+
 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return dim == -1 ? ShapedType::kDynamic : dim;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 806ba22e1bbe8c..d7067814e75d90 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -72,6 +72,31 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
 
 // -----
 
+func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
+  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32>
+  return
+}
+
+// -----
+
+func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) {
+  // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}}
+  %1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
+  return
+}
+
+// -----
+
+func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
+  %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  // expected-error@+1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<1xf32>'}}
+  %1 = tosa.pad %arg0, %arg1, %0 : (tensor<13x21xf32>, tensor<2x2xi32>, tensor<1xf32>) -> tensor<13x21xf32>
+  return
+}
+
+// -----
+
 func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
   // expected-error at +1 {{'tosa.transpose' op perms of transpose is not constant}}
   %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>



More information about the Mlir-commits mailing list