[Mlir-commits] [mlir] [mlir][tosa] Add canonicalization for adaptive to their non-adaptive variants (PR #195865)

Luke Hutton llvmlistbot at llvm.org
Tue May 5 07:37:18 PDT 2026


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/195865

This commit adds canonicalization patterns that convert the adaptive pooling ops (max and avg) to their non-adaptive variants when their CTC inputs (kernel, stride, and pad) are all constant.

This is beneficial for backends that do not support the adaptive op variants.

>From 890149d3de748710b4991a8f7c7b0480a1b526e4 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 5 May 2026 14:41:45 +0100
Subject: [PATCH] [mlir][tosa] Add canonicalization for adaptive to their
 non-adaptive variants

This commit adds canonicalization patterns that convert the adaptive
pooling ops (max and avg) to their non-adaptive variants when their CTC
inputs (kernel, stride, and pad) are all constant.

This is beneficial for backends that do not support the adaptive op
variants.

Change-Id: I9037438325a3b0071f14ebed0aa444acf66656df
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  4 +-
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 60 +++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 58 ++++++++++++++++++
 3 files changed, 121 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..71122ba8531c6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -161,6 +161,7 @@ def Tosa_AvgPool2dAdaptiveOp
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
@@ -536,7 +537,7 @@ def Tosa_MaxPool2dAdaptiveOp
     This performs a max pooling over the given input tensor. A sliding window of
     size given by <kernel size> is passed over the input tensor, with the
     maximum value being placed in the output tensor.
-    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
     pad arguments as inputs rather than attributes.
   }];
 
@@ -557,6 +558,7 @@ def Tosa_MaxPool2dAdaptiveOp
   ];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..642ee4b98e216 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -245,6 +245,36 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+struct AvgPool2dAdaptiveToAvgPool2d
+    : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernel;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> pad;
+    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::AvgPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
+        op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
+        rewriter.getDenseI64ArrayAttr(stride),
+        rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
+}
+
 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -283,6 +313,36 @@ void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+struct MaxPool2dAdaptiveToMaxPool2d
+    : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> kernel;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> pad;
+    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
+        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
+        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::MaxPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(),
+        rewriter.getDenseI64ArrayAttr(kernel),
+        rewriter.getDenseI64ArrayAttr(stride),
+        rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Data Layout / Memory Reinterpretation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..19583e111ebef 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1649,3 +1649,61 @@ func.func @test_do_not_canonicalize_cast_from_cast_to_block_scaled_unranked(%arg
   %1, %2 = tosa.cast_to_block_scaled %0 {block_size = BLOCK_SIZE_32} : (tensor<*xf32>) -> (tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>)
   return %1, %2 : tensor<*xf6E2M3FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: @canonicalize_max_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, nan_mode = IGNORE, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_max_pool2d_adaptive(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+         (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_avg_pool2d_adaptive
+// CHECK: %[[POOL:.+]] = tosa.avg_pool2d %arg0, %{{.*}}, %{{.*}} {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
+// CHECK: return %[[POOL]]
+func.func @canonicalize_avg_pool2d_adaptive(%arg0: tensor<1x7x7x9xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
+  %kernel = tosa.const_shape {values = dense<[3, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+         (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_max_pool2d_adaptive
+// CHECK: tosa.max_pool2d_adaptive
+func.func @dont_canonicalize_non_const_max_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+          (tensor<1x?x?x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_non_const_avg_pool2d_adaptive
+// CHECK: tosa.avg_pool2d_adaptive
+func.func @dont_canonicalize_non_const_avg_pool2d_adaptive(%arg0: tensor<1x?x?x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+          (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}



More information about the Mlir-commits mailing list