[Mlir-commits] [mlir] [mlir][linalg] Reject unsigned pooling on non-integer element types (PR #166070)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 2 06:09:05 PST 2025


https://github.com/Men-cotton created https://github.com/llvm/llvm-project/pull/166070

Addresses #164800.

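This change ensures that unsigned pooling ops in Linalg operate only on integer element types. Previously, the region builder silently emitted `arith.maximumf`/`arith.minimumf` for floating-point operands of `max_unsigned`/`min_unsigned`. It now rejects floating-point and boolean (`i1`) inputs with a clear diagnostic. New regression tests cover both the error path and valid integer examples. The transform-decomposition tests are updated to use integer element types.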

CC: @banach-space 

>From 03ef5fc57064198c3aa4424a722077ab94fbbda5 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Sun, 2 Nov 2025 22:59:39 +0900
Subject: [PATCH] [mlir][linalg] Reject unsigned pooling on non-integer element
 types

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 18 +++++++---
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 15 +++++++-
 mlir/test/Dialect/Linalg/named-ops.mlir       | 34 +++++++++++++++++++
 .../Linalg/transform-op-decompose.mlir        | 28 +++++++--------
 4 files changed, 76 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
       return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_unsigned:
       assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned max not on uint";
+          return nullptr;
+        }
+        llvm_unreachable("unsupported operation: unsigned max not on uint");
+      }
       return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::min_unsigned:
       assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned min not on uint";
+          return nullptr;
+        }
+        llvm_unreachable("unsupported operation: unsigned min not on uint");
+      }
       return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::powf:
       assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
 
 // -----
 
+func.func @pooling_nhwc_max_unsigned_float(
+    %input: tensor<?x?x?x?xf32>,
+    %filter: tensor<?x?xf32>,
+    %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK: unsupported operation: unsigned max not on uint
+  linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                    strides = dense<1> : tensor<2xi64>}
+      ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+     outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
 func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
   // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
   linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
   linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
   return
 }
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
   return %res : tensor<1x2x2x1xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+  %fake = tensor.empty() : tensor<3x3xi32>
+  %init = tensor.empty() : tensor<1x2x2x1xi32>
+  %cst = arith.constant 0 : i32
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  return %res : tensor<1x2x2x1xi32>
+}
+
 // -----
 // CHECK-LABEL: func @pooling_nwc_max_tensor
 // CHECK:         %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
 
 // -----
 
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+  %fake = tensor.empty() : tensor<3x3xi32>
+  %init = tensor.empty() : tensor<1x2x2x1xi32>
+  %cst = arith.constant 0 : i32
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pooling_nwc_min_tensor
 // CHECK:         %{{.+}} = linalg.pooling_nwc_min
 // CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 }
 
 // CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
   // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
   // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
   // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
   // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
-     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
-    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+     ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+    outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
   // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 }
 
 // CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
   // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
   // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
   // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
   // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
-     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
-    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+     ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+    outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
   // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nchw_max



More information about the Mlir-commits mailing list