[Mlir-commits] [mlir] ebb1c27 - [mlir][linalg] Reject unsigned pooling on non-integer element types (#166070)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 31 23:34:45 PST 2025
Author: Akimasa Watanuki
Date: 2026-01-01T13:04:41+05:30
New Revision: ebb1c27198bab15fc87ef1652f6dbc9b23b2754b
URL: https://github.com/llvm/llvm-project/commit/ebb1c27198bab15fc87ef1652f6dbc9b23b2754b
DIFF: https://github.com/llvm/llvm-project/commit/ebb1c27198bab15fc87ef1652f6dbc9b23b2754b.diff
LOG: [mlir][linalg] Reject unsigned pooling on non-integer element types (#166070)
Fixes: #164800
Ensures unsigned pooling ops in Linalg stay in the integer domain. The
lowering now rejects floating-point and bool (i1) element types with a
clear diagnostic. New regression tests lock in both the error path and a
valid integer example, and the transform decomposition tests are updated
to use integer element types.
Signed-off-by: Akimasa Watanuki <mencotton0410 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/named-ops-fail.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 33ec79b1b4b1b..210f9584c1e86 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned max not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned max not on uint");
+ }
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned min not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned min not on uint");
+ }
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..fb2570c7bb498 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -532,9 +532,9 @@ def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MaximumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if (
+ _is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
+ ) or _is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
@@ -546,9 +546,9 @@ def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
- return arith.MinimumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if (
+ _is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
+ ) or _is_index_type(lhs.type):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
@@ -634,6 +634,12 @@ def _is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)
+def _is_bool_type(t: Type) -> bool:
+ if not IntegerType.isinstance(t):
+ return False
+ return IntegerType(t).width == 1
+
+
def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 1d01d2dad3105..809a4208f8db0 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -481,19 +481,6 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filte
// -----
-func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
- %0 = linalg.pooling_nhwc_min_unsigned
- {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
- outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
- return %0 : tensor<?x?x?x?xf32>
-}
-// CHECK: @pooling_nhwc_min_unsigned_float
-// CHECK: linalg.pooling_nhwc_min
-// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-
-// -----
-
func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%0 = linalg.pooling_nchw_sum
{dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 1f554e6c45da7..5a699135604b7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1931,6 +1931,125 @@ func.func @reduce_non_operation_name(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -
// -----
+//===----------------------------------------------------------------------===//
+// linalg.pooling_nhwc_*
+//===----------------------------------------------------------------------===//
+
+func.func @pooling_nhwc_max_unsigned_float_type(
+ %input: tensor<1x4x4x1xf32>,
+ %filter: tensor<2x2xf32>,
+ %init_val: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+ // expected-error @+1 {{unsupported operation: unsigned max not on uint}}
+ %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x1xf32>, tensor<2x2xf32>)
+ outs (%init_val: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
+ return %0 : tensor<1x2x2x1xf32>
+}
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned_i1(
+ %input: tensor<1x4x4x1xi1>,
+ %filter: tensor<2x2xi1>,
+ %init_val: tensor<1x2x2x1xi1>) -> tensor<1x2x2x1xi1> {
+ // expected-error @+1 {{unsupported operation: unsigned max not on uint}}
+ %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x1xi1>, tensor<2x2xi1>)
+ outs (%init_val: tensor<1x2x2x1xi1>) -> tensor<1x2x2x1xi1>
+ return %0 : tensor<1x2x2x1xi1>
+}
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_float_type(
+ %input: tensor<1x4x4x1xf32>,
+ %filter: tensor<2x2xf32>,
+ %init_val: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> {
+ // expected-error @+1 {{unsupported operation: unsigned min not on uint}}
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x1xf32>, tensor<2x2xf32>)
+ outs (%init_val: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
+ return %0 : tensor<1x2x2x1xf32>
+}
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_i1(
+ %input: tensor<1x4x4x1xi1>,
+ %filter: tensor<2x2xi1>,
+ %init_val: tensor<1x2x2x1xi1>) -> tensor<1x2x2x1xi1> {
+ // expected-error @+1 {{unsupported operation: unsigned min not on uint}}
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x1xi1>, tensor<2x2xi1>)
+ outs (%init_val: tensor<1x2x2x1xi1>) -> tensor<1x2x2x1xi1>
+ return %0 : tensor<1x2x2x1xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.pooling_nwc_*
+//===----------------------------------------------------------------------===//
+
+func.func @pooling_nwc_max_unsigned_float_type(
+ %input: tensor<1x4x1xf32>,
+ %filter: tensor<2xf32>,
+ %init_val: tensor<1x2x1xf32>) -> tensor<1x2x1xf32> {
+ // expected-error @+1 {{unsupported operation: unsigned max not on uint}}
+ %0 = linalg.pooling_nwc_max_unsigned {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<1x4x1xf32>, tensor<2xf32>)
+ outs (%init_val: tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
+ return %0 : tensor<1x2x1xf32>
+}
+
+// -----
+
+func.func @pooling_nwc_max_unsigned_i1(
+ %input: tensor<1x4x1xi1>,
+ %filter: tensor<2xi1>,
+ %init_val: tensor<1x2x1xi1>) -> tensor<1x2x1xi1> {
+ // expected-error @+1 {{unsupported operation: unsigned max not on uint}}
+ %0 = linalg.pooling_nwc_max_unsigned {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<1x4x1xi1>, tensor<2xi1>)
+ outs (%init_val: tensor<1x2x1xi1>) -> tensor<1x2x1xi1>
+ return %0 : tensor<1x2x1xi1>
+}
+
+// -----
+
+func.func @pooling_nwc_min_unsigned_float_type(
+ %input: tensor<1x4x1xf32>,
+ %filter: tensor<2xf32>,
+ %init_val: tensor<1x2x1xf32>) -> tensor<1x2x1xf32> {
+ // expected-error @+1 {{unsupported operation: unsigned min not on uint}}
+ %0 = linalg.pooling_nwc_min_unsigned {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<1x4x1xf32>, tensor<2xf32>)
+ outs (%init_val: tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
+ return %0 : tensor<1x2x1xf32>
+}
+
+// -----
+
+func.func @pooling_nwc_min_unsigned_i1(
+ %input: tensor<1x4x1xi1>,
+ %filter: tensor<2xi1>,
+ %init_val: tensor<1x2x1xi1>) -> tensor<1x2x1xi1> {
+ // expected-error @+1 {{unsupported operation: unsigned min not on uint}}
+ %0 = linalg.pooling_nwc_min_unsigned {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: tensor<1x4x1xi1>, tensor<2xi1>)
+ outs (%init_val: tensor<1x2x1xi1>) -> tensor<1x2x1xi1>
+ return %0 : tensor<1x2x1xi1>
+}
+
+// -----
//===----------------------------------------------------------------------===//
// Tests for generic infrastructure for named Ops. The actual Ops used are
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..665119d94e534 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -349,4 +349,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
return
}
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..1e356c8fb4e72 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,42 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
return %res : tensor<1x2x2x1xf32>
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_i32
+// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_i32(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nwc_max_unsigned_i32
+// CHECK: %{{.+}} = linalg.pooling_nwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi32>, tensor<3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+func.func @pooling_nwc_max_unsigned_i32(%input: tensor<1x4x1xi32>) -> tensor<1x2x1xi32> {
+ %fake = tensor.empty() : tensor<3xi32>
+ %init = tensor.empty() : tensor<1x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+ %res = linalg.pooling_nwc_max_unsigned {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %fake: tensor<1x4x1xi32>, tensor<3xi32>)
+ outs(%fill: tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+ return %res : tensor<1x2x1xi32>
+}
+
// -----
// CHECK-LABEL: func @pooling_nwc_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1053,42 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
// -----
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_i32
+// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_i32(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nwc_min_unsigned_i32
+// CHECK: %{{.+}} = linalg.pooling_nwc_min_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x1xi32>, tensor<3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+func.func @pooling_nwc_min_unsigned_i32(%input: tensor<1x4x1xi32>) -> tensor<1x2x1xi32> {
+ %fake = tensor.empty() : tensor<3xi32>
+ %init = tensor.empty() : tensor<1x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+ %res = linalg.pooling_nwc_min_unsigned {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %fake: tensor<1x4x1xi32>, tensor<3xi32>)
+ outs(%fill: tensor<1x2x1xi32>) -> tensor<1x2x1xi32>
+ return %res : tensor<1x2x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @pooling_nwc_min_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_min
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index 4ce0fbc1dbe53..0df87de6393d8 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -150,3 +150,51 @@ def test_f32f32_min_pooling(input, shape, init_result):
print(module)
+
+with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f32 = F32Type.get()
+ bool_t = IntegerType.get_signless(1)
+
+ # CHECK: bool_max_unsigned_error: Unsupported 'max_unsigned' operands
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), bool_t),
+ )
+ def test_bool_i1_max_unsigned_pooling_error(input, shape, init_result):
+ try:
+ pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.max_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
+ except NotImplementedError as e:
+ print(f"bool_max_unsigned_error: {e}")
+ return init_result
+
+ # CHECK: float_max_unsigned_error: Unsupported 'max_unsigned' operands
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((1, 4, 16, 1), f32),
+ RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((1, 2, 4, 1), f32),
+ )
+ def test_f32f32_max_unsigned_pooling_error(input, shape, init_result):
+ try:
+ pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.max_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2],
+ )
+ except NotImplementedError as e:
+ print(f"float_max_unsigned_error: {e}")
+ return init_result
More information about the Mlir-commits
mailing list