[Mlir-commits] [mlir] [mlir][tosa] Stop folding pad into avg_pool2d (PR #164599)
Vitalii Shutov
llvmlistbot at llvm.org
Thu Oct 23 04:36:58 PDT 2025
https://github.com/Lallapallooza updated https://github.com/llvm/llvm-project/pull/164599
>From 44518644d3abd56d7d9f02e64df5228e7c969ac0 Mon Sep 17 00:00:00 2001
From: Vitalii Shutov <vitalii.shutov at arm.com>
Date: Tue, 21 Oct 2025 07:49:21 +0100
Subject: [PATCH] [mlir][tosa] Stop folding pad into avg_pool2d
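Keep explicit padding ahead of tosa.avg_pool2d to preserve semantics.
Folding a tosa.pad into the op's pad attribute changes the result: elements
introduced by an explicit tosa.pad are counted in the average divisor, whereas
avg_pool2d excludes its own padding from the divisor.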
Change-Id: I229bbdc0a8ef5d4ff4c6942788614c55593ce30f
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 -
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 29 -------------------
mlir/test/Dialect/Tosa/canonicalize.mlir | 8 ++---
3 files changed, 4 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 697a04e94441a..137554f49460d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,7 +108,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];
- let hasCanonicalizer = 1;
let hasVerifier = 1;
let assemblyFormat =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf80165fc640..6f32f601e5be0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -75,28 +75,6 @@ namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;
-template <>
-struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
- using OpTy = tosa::AvgPool2dOp;
- static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
- const llvm::ArrayRef<int64_t> kernel = op.getKernel();
- if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
- newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
- return false;
- return true;
- }
- static bool checkPadConstCompliance(OpTy op, Value padConst) {
- return checkMatchingPadConstAndZp(padConst, op.getInputZp());
- }
- static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
- Value padInput, ArrayRef<int64_t> newPad) {
- rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
- op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
- op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
- op.getAccType());
- }
-};
-
template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
};
} // namespace
-void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
- PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
- context);
-}
-
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e8525a5d2ed62..45d942bb92d6c 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,11 +9,11 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// -----
-// CHECK-LABEL: @pad_wh_avg_pool2d_fold
-func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
- // CHECK-NOT: tosa.pad
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold
+func.func @pad_wh_avg_pool2d_nofold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+ // CHECK: tosa.pad
// CHECK: tosa.avg_pool2d
- // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+ // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
%pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
More information about the Mlir-commits mailing list