[Mlir-commits] [mlir] [mlir][affine] Support signless types for max/min/si/ui reductions. (PR #189480)

Slava Zakharin llvmlistbot at llvm.org
Mon Mar 30 13:59:26 PDT 2026


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/189480

Because the reduction kind of a maxs/mins/maxu/minu reduction already
defines how its integer arguments are interpreted (signed or unsigned),
we can safely support signless integer types for these reductions.
Moreover, the affine-to-standard conversion does not even work for
non-signless (signed/unsigned) integer types, because
`mlir::arith::getIdentityValue` generates invalid identity values for them.


From 88cd7567ae89f85532ac2301f673a7638f6664f1 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 30 Mar 2026 13:45:02 -0700
Subject: [PATCH] [mlir][affine] Support signless types for max/min/si/ui
 reductions.

Because the reduction kind of a maxs/mins/maxu/minu reduction already
defines how its integer arguments are interpreted (signed or unsigned),
we can safely support signless integer types for these reductions.
Moreover, the affine-to-standard conversion does not even work for
non-signless (signed/unsigned) integer types, because
`mlir::arith::getIdentityValue` generates invalid identity values for them.
---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  8 +-
 .../AffineToStandard/lower-affine.mlir        | 98 +++++++++++++++++++
 mlir/test/Dialect/Affine/ops.mlir             | 59 +++++++++++
 3 files changed, 161 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..18a7ffb24b2bc 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4218,19 +4218,19 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
     return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::maxs: {
     auto intType = dyn_cast<IntegerType>(resultType);
-    return intType && intType.isSigned();
+    return intType && (intType.isSigned() || intType.isSignless());
   }
   case arith::AtomicRMWKind::mins: {
     auto intType = dyn_cast<IntegerType>(resultType);
-    return intType && intType.isSigned();
+    return intType && (intType.isSigned() || intType.isSignless());
   }
   case arith::AtomicRMWKind::maxu: {
     auto intType = dyn_cast<IntegerType>(resultType);
-    return intType && intType.isUnsigned();
+    return intType && (intType.isUnsigned() || intType.isSignless());
   }
   case arith::AtomicRMWKind::minu: {
     auto intType = dyn_cast<IntegerType>(resultType);
-    return intType && intType.isUnsigned();
+    return intType && (intType.isUnsigned() || intType.isSignless());
   }
   case arith::AtomicRMWKind::ori:
   case arith::AtomicRMWKind::andi:
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..2013c2899ca70 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,3 +927,101 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 // CHECK:      scf.reduce.return %[[RES]] : i64
 // CHECK:    }
 // CHECK:  }
+
+/////////////////////////////////////////////////////////////////////
+
+// CHECK-LABEL:   func.func @parallel_maxsi_reduce(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant -2147483648 : i32
+// CHECK:           %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK:             %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK:             scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK:             ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK:               %[[MAXSI_0:.*]] = arith.maxsi %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK:               scf.reduce.return %[[MAXSI_0]] : i32
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[PARALLEL_0]] : i32
+// CHECK:         }
+func.func @parallel_maxsi_reduce(%arg0: memref<100xi32>) -> (i32) {
+  %12 = affine.parallel (%i) = (0) to (100) reduce ("maxs") -> (i32) {
+    %2 = affine.load %arg0[%i] : memref<100xi32>
+    affine.yield %2 : i32
+  }
+  return %12 : i32
+}
+
+// CHECK-LABEL:   func.func @parallel_minsi_reduce(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant 2147483647 : i32
+// CHECK:           %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK:             %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK:             scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK:             ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK:               %[[MINSI_0:.*]] = arith.minsi %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK:               scf.reduce.return %[[MINSI_0]] : i32
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[PARALLEL_0]] : i32
+// CHECK:         }
+func.func @parallel_minsi_reduce(%arg0: memref<100xi32>) -> (i32) {
+  %12 = affine.parallel (%i) = (0) to (100) reduce ("mins") -> (i32) {
+    %2 = affine.load %arg0[%i] : memref<100xi32>
+    affine.yield %2 : i32
+  }
+  return %12 : i32
+}
+
+// CHECK-LABEL:   func.func @parallel_maxui_reduce(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK:             %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK:             scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK:             ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK:               %[[MAXUI_0:.*]] = arith.maxui %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK:               scf.reduce.return %[[MAXUI_0]] : i32
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[PARALLEL_0]] : i32
+// CHECK:         }
+func.func @parallel_maxui_reduce(%arg0: memref<100xi32>) -> (i32) {
+  %12 = affine.parallel (%i) = (0) to (100) reduce ("maxu") -> (i32) {
+    %2 = affine.load %arg0[%i] : memref<100xi32>
+    affine.yield %2 : i32
+  }
+  return %12 : i32
+}
+
+// CHECK-LABEL:   func.func @parallel_minui_reduce(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<100xi32>) -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[CONSTANT_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[CONSTANT_3:.*]] = arith.constant -1 : i32
+// CHECK:           %[[PARALLEL_0:.*]] = scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_0]]) to (%[[CONSTANT_1]]) step (%[[CONSTANT_2]]) init (%[[CONSTANT_3]]) -> i32 {
+// CHECK:             %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<100xi32>
+// CHECK:             scf.reduce(%[[LOAD_0]] : i32) {
+// CHECK:             ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+// CHECK:               %[[MINUI_0:.*]] = arith.minui %[[VAL_1]], %[[VAL_2]] : i32
+// CHECK:               scf.reduce.return %[[MINUI_0]] : i32
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[PARALLEL_0]] : i32
+// CHECK:         }
+func.func @parallel_minui_reduce(%arg0: memref<100xi32>) -> (i32) {
+  %12 = affine.parallel (%i) = (0) to (100) reduce ("minu") -> (i32) {
+    %2 = affine.load %arg0[%i] : memref<100xi32>
+    affine.yield %2 : i32
+  }
+  return %12 : i32
+}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..2a178e975dacd 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,62 @@ func.func @parallel_minnumf_reduce() {
   return
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @parallel_maxsi_reduce() {
+// CHECK:           affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("maxs", "maxs") -> (i32, si32) {
+func.func @parallel_maxsi_reduce() {
+  %0 = memref.alloc() : memref<100xi32>
+  %1 = memref.alloc() : memref<100xsi32>
+  %12:2 = affine.parallel (%i) = (0) to (100) reduce ("maxs", "maxs") -> (i32, si32) {
+    %2 = affine.load %0[%i] : memref<100xi32>
+    %3 = affine.load %1[%i] : memref<100xsi32>
+    affine.yield %2, %3 : i32, si32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @parallel_minsi_reduce() {
+// CHECK:           affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("mins", "mins") -> (i32, si32) {
+func.func @parallel_minsi_reduce() {
+  %0 = memref.alloc() : memref<100xi32>
+  %1 = memref.alloc() : memref<100xsi32>
+  %12:2 = affine.parallel (%i) = (0) to (100) reduce ("mins", "mins") -> (i32, si32) {
+    %2 = affine.load %0[%i] : memref<100xi32>
+    %3 = affine.load %1[%i] : memref<100xsi32>
+    affine.yield %2, %3 : i32, si32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @parallel_maxui_reduce() {
+// CHECK:           affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("maxu", "maxu") -> (i32, ui32) {
+func.func @parallel_maxui_reduce() {
+  %0 = memref.alloc() : memref<100xi32>
+  %1 = memref.alloc() : memref<100xui32>
+  %12:2 = affine.parallel (%i) = (0) to (100) reduce ("maxu", "maxu") -> (i32, ui32) {
+    %2 = affine.load %0[%i] : memref<100xi32>
+    %3 = affine.load %1[%i] : memref<100xui32>
+    affine.yield %2, %3 : i32, ui32
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @parallel_minui_reduce() {
+// CHECK:           affine.parallel (%[[VAL_0:.*]]) = (0) to (100) reduce ("minu", "minu") -> (i32, ui32) {
+func.func @parallel_minui_reduce() {
+  %0 = memref.alloc() : memref<100xi32>
+  %1 = memref.alloc() : memref<100xui32>
+  %12:2 = affine.parallel (%i) = (0) to (100) reduce ("minu", "minu") -> (i32, ui32) {
+    %2 = affine.load %0[%i] : memref<100xi32>
+    %3 = affine.load %1[%i] : memref<100xui32>
+    affine.yield %2, %3 : i32, ui32
+  }
+  return
+}



More information about the Mlir-commits mailing list