[Mlir-commits] [mlir] [mlir][linalg] Reject unsigned pooling on non-integer element types (PR #166070)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 2 06:09:53 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Men-cotton (Men-cotton)
<details>
<summary>Changes</summary>
Addresses issue #164800.
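Ensures that the unsigned pooling ops in Linalg stay in the integer domain. Previously, the region builder silently lowered `max_unsigned`/`min_unsigned` on floating-point inputs to `arith.maximumf`/`arith.minimumf`. It now rejects floating-point and boolean (`i1`) element types with a clear diagnostic. New regression tests cover both the error path and valid integer examples, and the transform decomposition tests are updated to use integer element types.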
CC: @banach-space
---
Full diff: https://github.com/llvm/llvm-project/pull/166070.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+14-4)
- (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+14-1)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+34)
- (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+14-14)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned max not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned max not on uint");
+ }
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned min not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned min not on uint");
+ }
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
// -----
+func.func @pooling_nhwc_max_unsigned_float(
+ %input: tensor<?x?x?x?xf32>,
+ %filter: tensor<?x?xf32>,
+ %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // CHECK: unsupported operation: unsigned max not on uint
+ linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
// CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
return
}
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
return %res : tensor<1x2x2x1xf32>
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
// -----
// CHECK-LABEL: func @pooling_nwc_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
// -----
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @pooling_nwc_min_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_min
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
``````````
</details>
https://github.com/llvm/llvm-project/pull/166070
More information about the Mlir-commits mailing list