[Mlir-commits] [mlir] [mlir][tosa] Add canonicalization for adaptive pooling ops to their non-adaptive variants (PR #195865)
Luke Hutton
llvmlistbot at llvm.org
Tue May 5 07:37:18 PDT 2026
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/195865
This commit adds canonicalization patterns to convert adaptive pooling (max and avg) to their non-adaptive variants when their CTC inputs (kernel, stride and pad) are constants.
This is beneficial for backends that do not support the adaptive op variants.
>From 890149d3de748710b4991a8f7c7b0480a1b526e4 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 5 May 2026 14:41:45 +0100
Subject: [PATCH] [mlir][tosa] Add canonicalization for adaptive pooling ops to
their non-adaptive variants
This commit adds canonicalization patterns to convert adaptive pooling
(max and avg) to their non-adaptive variants when their CTC inputs
(kernel, stride and pad) are constants.
This is beneficial for backends that do not support the adaptive op
variants.
Change-Id: I9037438325a3b0071f14ebed0aa444acf66656df
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 4 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 60 +++++++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 58 ++++++++++++++++++
3 files changed, 121 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..71122ba8531c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -161,6 +161,7 @@ def Tosa_AvgPool2dAdaptiveOp
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
@@ -536,7 +537,7 @@ def Tosa_MaxPool2dAdaptiveOp
This performs a max pooling over the given input tensor. A sliding window of
size given by <kernel size> is passed over the input tensor, with the
maximum value being placed in the output tensor.
- Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
+ Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
pad arguments as inputs rather than attributes.
}];
@@ -557,6 +558,7 @@ def Tosa_MaxPool2dAdaptiveOp
];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..642ee4b98e216 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -245,6 +245,36 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
+/// Rewrites tosa.avg_pool2d_adaptive as tosa.avg_pool2d when
+/// tosa::getConstShapeValues can resolve the kernel, stride and pad operands
+/// to constant values; those values become the non-adaptive op's attributes.
+struct AvgPool2dAdaptiveToAvgPool2d
+    : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernelValues;
+    llvm::SmallVector<int64_t> strideValues;
+    llvm::SmallVector<int64_t> padValues;
+    // Short-circuit evaluation: the first operand that is not constant stops
+    // the match before the remaining operands are inspected.
+    const bool allConstant =
+        tosa::getConstShapeValues(op.getKernel().getDefiningOp(),
+                                  kernelValues) &&
+        tosa::getConstShapeValues(op.getStride().getDefiningOp(),
+                                  strideValues) &&
+        tosa::getConstShapeValues(op.getPad().getDefiningOp(), padValues);
+    if (!allConstant)
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    // Result type, zero-point operands and accumulator type are forwarded
+    // unchanged; only the window parameters move from operands to attributes.
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), op.getInput(), op.getInputZp(), op.getOutputZp(),
+        rewriter.getDenseI64ArrayAttr(kernelValues),
+        rewriter.getDenseI64ArrayAttr(strideValues),
+        rewriter.getDenseI64ArrayAttr(padValues), op.getAccTypeAttr());
+    return success();
+  }
+};
+
+void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
+}
+
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
@@ -283,6 +313,36 @@ void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
+/// Rewrites tosa.max_pool2d_adaptive as tosa.max_pool2d when
+/// tosa::getConstShapeValues can resolve the kernel, stride and pad operands
+/// to constant values; those values become the non-adaptive op's attributes.
+struct MaxPool2dAdaptiveToMaxPool2d
+    : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernelValues;
+    llvm::SmallVector<int64_t> strideValues;
+    llvm::SmallVector<int64_t> padValues;
+    // Short-circuit evaluation: the first operand that is not constant stops
+    // the match before the remaining operands are inspected.
+    const bool allConstant =
+        tosa::getConstShapeValues(op.getKernel().getDefiningOp(),
+                                  kernelValues) &&
+        tosa::getConstShapeValues(op.getStride().getDefiningOp(),
+                                  strideValues) &&
+        tosa::getConstShapeValues(op.getPad().getDefiningOp(), padValues);
+    if (!allConstant)
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    // Result type and NaN mode are carried over unchanged; only the window
+    // parameters move from operands to attributes.
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), op.getInput(),
+        rewriter.getDenseI64ArrayAttr(kernelValues),
+        rewriter.getDenseI64ArrayAttr(strideValues),
+        rewriter.getDenseI64ArrayAttr(padValues), op.getNanModeAttr());
+    return success();
+  }
+};
+
+void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
+}
+
//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..19583e111ebef 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1649,3 +1649,61 @@ func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_unranked(%arg
%1, %2 = tosa.cast_to_block_scaled %0 {block_size = BLOCK_SIZE_32} : (tensor<*xf32>) -> (tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>)
return %1, %2 : tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>
}
+
+// -----
+
+// The kernel, stride and pad operands are all tosa.const_shape values, so the
+// adaptive op is expected to become tosa.max_pool2d with those values as
+// attributes and the nan_mode attribute carried over.
+// CHECK-LABEL: @canonicalize_max_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, nan_mode = IGNORE, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_max_pool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+ (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// Only kernel, stride and pad must be constant for the rewrite: the zero
+// points here are function arguments and are expected to be forwarded as
+// operands of tosa.avg_pool2d, with acc_type carried over as an attribute.
+// CHECK-LABEL: @canonicalize_avg_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.avg_pool2d %arg0, %{{.*}}, %{{.*}} {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_avg_pool2d_adaptive(%arg0: tensor<1x7x7x9xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
+ %kernel = tosa.const_shape {values = dense<[3, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+ (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x7x7x9xf32>
+ return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// The kernel is assembled from tosa.dim results via tosa.concat_shape rather
+// than a tosa.const_shape, so the adaptive op is expected to be left in place
+// even though stride and pad are constant.
+// CHECK-LABEL: @dont_canonicalize_non_const_max_pool2d_adaptive
+// CHECK: tosa.max_pool2d_adaptive
+func.func @dont_canonicalize_non_const_max_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>) -> tensor<1x?x?x8xf32> {
+ %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+ (tensor<1x?x?x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+ return %0 : tensor<1x?x?x8xf32>
+}
+
+// -----
+
+// The kernel is assembled from tosa.dim results via tosa.concat_shape rather
+// than a tosa.const_shape, so the adaptive op is expected to be left in place
+// even though stride and pad are constant.
+// CHECK-LABEL: @dont_canonicalize_non_const_avg_pool2d_adaptive
+// CHECK: tosa.avg_pool2d_adaptive
+func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x?x?x8xf32> {
+ %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+ %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+ (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+ return %0 : tensor<1x?x?x8xf32>
+}
More information about the Mlir-commits
mailing list