[Mlir-commits] [mlir] [mlir][linalg] Reject unsigned pooling on non-integer element types (PR #166070)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 2 06:09:05 PST 2025
https://github.com/Men-cotton created https://github.com/llvm/llvm-project/pull/166070
#164800
Ensures that unsigned pooling ops in Linalg stay in the integer domain. The region builder now rejects floating-point and bool element types for `max_unsigned`/`min_unsigned` with a clear diagnostic instead of falling back to floating-point max/min ops; new regression tests lock in both the error path and valid integer examples; and the transform-decompose tests are updated to use integer element types.
CC: @banach-space
From 03ef5fc57064198c3aa4424a722077ab94fbbda5 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Sun, 2 Nov 2025 22:59:39 +0900
Subject: [PATCH] [mlir][linalg] Reject unsigned pooling on non-integer element types
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 +++++++---
mlir/test/Dialect/Linalg/named-ops-fail.mlir | 15 +++++++-
mlir/test/Dialect/Linalg/named-ops.mlir | 34 +++++++++++++++++++
.../Linalg/transform-op-decompose.mlir | 28 +++++++--------
4 files changed, 76 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned max not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned max not on uint");
+ }
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned min not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned min not on uint");
+ }
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
// -----
+func.func @pooling_nhwc_max_unsigned_float(
+ %input: tensor<?x?x?x?xf32>,
+ %filter: tensor<?x?xf32>,
+ %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // CHECK: unsupported operation: unsigned max not on uint
+ linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
// CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
return
}
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
return %res : tensor<1x2x2x1xf32>
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
// -----
// CHECK-LABEL: func @pooling_nwc_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
// -----
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @pooling_nwc_min_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_min
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
More information about the Mlir-commits
mailing list