[Mlir-commits] [mlir] [mlir][tosa] Add missing verifier for `tosa.pad` (PR #120934)
Longsheng Mou
llvmlistbot at llvm.org
Mon Dec 23 05:32:14 PST 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/120934
>From b3424d55104ca02004733fa23fd74610b59b7d1b Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 23 Dec 2024 14:46:50 +0800
Subject: [PATCH] [mlir][tosa] Add missing verifier for `tosa.pad`
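This PR adds the missing verification for `tosa.pad`: the `padding` operand is now constrained to a 2-D tensor, and the verifier checks that its shape is [rank(input), 2] (dynamic dimensions of `padding` are not checked).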
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 ++++++++++---
mlir/test/Dialect/Tosa/invalid.mlir | 18 +++++++++++++++++-
3 files changed, 28 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..9ca5f9f959c891 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1566,7 +1566,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let arguments = (ins
Tosa_RankedTensor:$input1,
- Tosa_Int32Or64Tensor:$padding,
+ 2DTensorOf<[Tosa_Int32Or64]>:$padding, // [rank(input1), 2]; see verify()
Optional<Tosa_ScalarTensor>:$pad_const,
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..ef495115e98381 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -823,13 +823,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::verify() {
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
- TensorType paddingType = getPadding().getType();
+ RankedTensorType paddingType = getPadding().getType();
if (inputType.getRank() != outputType.getRank())
return emitOpError() << "expect same input and output tensor rank.";
- if (paddingType.hasRank() && paddingType.getRank() != 2)
- return emitOpError() << "expect 'padding' tensor rank equal to 2.";
+ // Expect padding shape [rank(input1), 2]; dynamic extents are not checked.
+ const int64_t padRows = paddingType.getDimSize(0);
+ if (!paddingType.isDynamicDim(0) && padRows != inputType.getRank())
+ return emitOpError() << "expected padding tensor dim 0 to have size "
+ << inputType.getRank() << " (input rank) but got size "
+ << padRows;
+ if (paddingType.getDimSize(1) != 2 && !paddingType.isDynamicDim(1))
+ return emitOpError() << "expected padding tensor dim 1 to have size 2 "
+ << "but got size " << paddingType.getDimSize(1);
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..f0c673b9f3b47c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -103,7 +103,7 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2
// -----
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2xi32>) {
- // expected-error@+1 {{'tosa.pad' op expect 'padding' tensor rank equal to 2.}}
+ // expected-error@+1 {{'tosa.pad' op operand #1 must be 2D tensor of 32-bit signless integer or 64-bit signless integer values, but got 'tensor<2xi32>'}}
%1 = tosa.pad %arg0, %arg1 : (tensor<13x21xf32>, tensor<2xi32>) -> tensor<13x21xf32>
return
}
@@ -119,6 +119,22 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
// -----
+func.func @test_pad_padding_shape_mismatch_dim0(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x2xi32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 3 (input rank) but got size 2}}
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<2x2xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_pad_padding_shape_mismatch_dim1(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x1xi32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 1 to have size 2 but got size 1}}
+ %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x1xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
// expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
More information about the Mlir-commits
mailing list