[Mlir-commits] [mlir] [mlir][affine] Support signless types for max/min/si/ui reductions. (PR #189480)
Slava Zakharin
llvmlistbot at llvm.org
Mon Mar 30 13:59:26 PDT 2026
https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/189480
Because the maxs/mins/maxu/minu reduction kinds themselves define whether
the integer operands are interpreted as signed or unsigned, we can safely
support signless integer types for them. Moreover, affine-to-standard
conversion does not even work for signed/unsigned (i.e. non-signless)
integer types, because `mlir::arith::getIdentityValue` generates invalid
identity values for them.
From 88cd7567ae89f85532ac2301f673a7638f6664f1 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 30 Mar 2026 13:45:02 -0700
Subject: [PATCH] [mlir][affine] Support signless types for max/min/si/ui
reductions.
Because the maxs/mins/maxu/minu reduction kinds themselves define whether
the integer operands are interpreted as signed or unsigned, we can safely
support signless integer types for them. Moreover, affine-to-standard
conversion does not even work for signed/unsigned (i.e. non-signless)
integer types, because `mlir::arith::getIdentityValue` generates invalid
identity values for them.
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8 +-
.../AffineToStandard/lower-affine.mlir | 98 +++++++++++++++++++
mlir/test/Dialect/Affine/ops.mlir | 59 +++++++++++
3 files changed, 161 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..18a7ffb24b2bc 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4218,19 +4218,19 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
return isa<FloatType>(resultType);
case arith::AtomicRMWKind::maxs: {
auto intType = dyn_cast<IntegerType>(resultType);
- return intType && intType.isSigned();
+ return intType && (intType.isSigned() || intType.isSignless());
}
case arith::AtomicRMWKind::mins: {
auto intType = dyn_cast<IntegerType>(resultType);
- return intType && intType.isSigned();
+ return intType && (intType.isSigned() || intType.isSignless());
}
case arith::AtomicRMWKind::maxu: {
auto intType = dyn_cast<IntegerType>(resultType);
- return intType && intType.isUnsigned();
+ return intType && (intType.isUnsigned() || intType.isSignless());
}
case arith::AtomicRMWKind::minu: {
auto intType = dyn_cast<IntegerType>(resultType);
- return intType && intType.isUnsigned();
+ return intType && (intType.isUnsigned() || intType.isSignless());
}
case arith::AtomicRMWKind::ori:
case arith::AtomicRMWKind::andi:
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..2013c2899ca70 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,3 +927,101 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
// CHECK: scf.reduce.return %[[RES]] : i64
// CHECK: }
// CHECK: }
+
+/////////////////////////////////////////////////////////////////////
+
+// CHECK-LABEL: func.func @parallel_maxsi_reduce(
+// CHECK-SAME: %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant -2147483648 : i32
+// CHECK: %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK: scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK: %[[MAXSI_0:.*]] = arith.maxsi %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK: scf.reduce.return %[[MAXSI_0]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[PARALLEL_0]] : i32
+// CHECK: }
+func.func @parallel_maxsi_reduce(%arg0: memref<100xi32>) -> (i32) {
+ %max = affine.parallel (%iv) = (0) to (100) reduce ("maxs") -> (i32) {
+ %elem = affine.load %arg0[%iv] : memref<100xi32>
+ affine.yield %elem : i32
+ }
+ return %max : i32
+}
+
+// CHECK-LABEL: func.func @parallel_minsi_reduce(
+// CHECK-SAME: %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 2147483647 : i32
+// CHECK: %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK: scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK: %[[MINSI_0:.*]] = arith.minsi %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK: scf.reduce.return %[[MINSI_0]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[PARALLEL_0]] : i32
+// CHECK: }
+func.func @parallel_minsi_reduce(%arg0: memref<100xi32>) -> (i32) {
+ %min = affine.parallel (%iv) = (0) to (100) reduce ("mins") -> (i32) {
+ %elem = affine.load %arg0[%iv] : memref<100xi32>
+ affine.yield %elem : i32
+ }
+ return %min : i32
+}
+
+// CHECK-LABEL: func.func @parallel_maxui_reduce(
+// CHECK-SAME: %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+// CHECK: %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK: scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK: %[[MAXUI_0:.*]] = arith.maxui %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK: scf.reduce.return %[[MAXUI_0]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[PARALLEL_0]] : i32
+// CHECK: }
+func.func @parallel_maxui_reduce(%arg0: memref<100xi32>) -> (i32) {
+ %max = affine.parallel (%iv) = (0) to (100) reduce ("maxu") -> (i32) {
+ %elem = affine.load %arg0[%iv] : memref<100xi32>
+ affine.yield %elem : i32
+ }
+ return %max : i32
+}
+
+// CHECK-LABEL: func.func @parallel_minui_reduce(
+// CHECK-SAME: %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant -1 : i32
+// CHECK: %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK: scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK: %[[MINUI_0:.*]] = arith.minui %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK: scf.reduce.return %[[MINUI_0]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[PARALLEL_0]] : i32
+// CHECK: }
+func.func @parallel_minui_reduce(%arg0: memref<100xi32>) -> (i32) {
+ %min = affine.parallel (%iv) = (0) to (100) reduce ("minu") -> (i32) {
+ %elem = affine.load %arg0[%iv] : memref<100xi32>
+ affine.yield %elem : i32
+ }
+ return %min : i32
+}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..2a178e975dacd 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,62 @@ func.func @parallel_minnumf_reduce() {
return
}
+// -----
+
+// CHECK-LABEL: func.func @parallel_maxsi_reduce() {
+// CHECK: affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("maxs", "maxs") -> (i32, si32) {
+func.func @parallel_maxsi_reduce() {
+ %signless_buf = memref.alloc() : memref<100xi32>
+ %signed_buf = memref.alloc() : memref<100xsi32>
+ %red:2 = affine.parallel (%iv) = (0) to (100) reduce ("maxs", "maxs") -> (i32, si32) {
+ %a = affine.load %signless_buf[%iv] : memref<100xi32>
+ %b = affine.load %signed_buf[%iv] : memref<100xsi32>
+ affine.yield %a, %b : i32, si32
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_minsi_reduce() {
+// CHECK: affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("mins", "mins") -> (i32, si32) {
+func.func @parallel_minsi_reduce() {
+ %signless_buf = memref.alloc() : memref<100xi32>
+ %signed_buf = memref.alloc() : memref<100xsi32>
+ %red:2 = affine.parallel (%iv) = (0) to (100) reduce ("mins", "mins") -> (i32, si32) {
+ %a = affine.load %signless_buf[%iv] : memref<100xi32>
+ %b = affine.load %signed_buf[%iv] : memref<100xsi32>
+ affine.yield %a, %b : i32, si32
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_maxui_reduce() {
+// CHECK: affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("maxu", "maxu") -> (i32, ui32) {
+func.func @parallel_maxui_reduce() {
+ %signless_buf = memref.alloc() : memref<100xi32>
+ %unsigned_buf = memref.alloc() : memref<100xui32>
+ %red:2 = affine.parallel (%iv) = (0) to (100) reduce ("maxu", "maxu") -> (i32, ui32) {
+ %a = affine.load %signless_buf[%iv] : memref<100xi32>
+ %b = affine.load %unsigned_buf[%iv] : memref<100xui32>
+ affine.yield %a, %b : i32, ui32
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_minui_reduce() {
+// CHECK: affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("minu", "minu") -> (i32, ui32) {
+func.func @parallel_minui_reduce() {
+ %signless_buf = memref.alloc() : memref<100xi32>
+ %unsigned_buf = memref.alloc() : memref<100xui32>
+ %red:2 = affine.parallel (%iv) = (0) to (100) reduce ("minu", "minu") -> (i32, ui32) {
+ %a = affine.load %signless_buf[%iv] : memref<100xi32>
+ %b = affine.load %unsigned_buf[%iv] : memref<100xui32>
+ affine.yield %a, %b : i32, ui32
+ }
+ return
+}
More information about the Mlir-commits
mailing list