[Mlir-commits] [mlir] [mlir][gpu] Align reduction operations with vector combining kinds (PR #73423)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 26 13:06:18 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73423

From 3c561a50010a9a2dd100bfa18aa134b95560584a Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 25 Nov 2023 23:20:22 -0500
Subject: [PATCH 1/2] [mlir][gpu] Align reduction operations with vector
 combining kinds

The motivation for this change is explained in
https://github.com/llvm/llvm-project/issues/72354.

Before this change, we could not distinguish between signed and unsigned
minimum/maximum, nor specify the NaN treatment for floating-point values.

The mapping of old reduction operations to the new ones is as follows:
*  `min` --> `minsi` for ints, `minf` for floats
*  `max` --> `maxsi` for ints, `maxf` for floats
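
For illustration, a minimal sketch of the rename as it appears in IR (the
`%x`/`%y` values are hypothetical; the op syntax matches the updated tests):

  // Old spelling: signedness and NaN semantics are implicit.
  %r0 = gpu.all_reduce min %x {} : (i32) -> (i32)
  %r1 = gpu.all_reduce max %y {} : (f32) -> (f32)

  // New spelling: the combining kind is explicit.
  %r0 = gpu.all_reduce minsi %x {} : (i32) -> (i32)
  %r1 = gpu.all_reduce maxf %y {} : (f32) -> (f32)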

New reduction kinds not represented in the old enum: `minui`, `maxui`,
`minimumf`, `maximumf`.
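
The `minf`/`maxf` vs. `minimumf`/`maximumf` distinction matters once NaNs
are involved. A sketch of the intended semantics, assuming some work item
contributes a NaN (`%v` is a hypothetical per-work-item value):

  // `minf` follows `arith.minnumf`: a NaN operand is ignored in favor of
  // the other operand, so the group result is the smallest number.
  %a = gpu.all_reduce minf %v {} : (f32) -> (f32)
  // `minimumf` follows `arith.minimumf`: NaN propagates, so the group
  // result is NaN if any work item contributes one.
  %b = gpu.all_reduce minimumf %v {} : (f32) -> (f32)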

As a next step, I would like to have a common definition of combining
kinds used by the `vector` and `gpu` dialects. Separately, the GPU to
SPIR-V lowering does not yet properly handle signed zeros and NaN values --
the behavior of floating-point min/max group reductions is not specified
by the SPIR-V spec.

Issue: https://github.com/llvm/llvm-project/issues/72354
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  52 ++++++---
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  21 +++-
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp |  58 +++++++---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  39 +++++--
 .../GPU/Transforms/AllReduceLowering.cpp      |  51 ++++-----
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  22 ++--
 .../Conversion/GPUToSPIRV/reductions.mlir     |  49 +++++---
 .../{all-reduce.mlir => all-reduce-add.mlir}  |   0
 ...l-reduce-max.mlir => all-reduce-maxf.mlir} |  72 +++++-------
 mlir/test/Dialect/GPU/invalid.mlir            | 108 ++++++++++++++++--
 10 files changed, 313 insertions(+), 159 deletions(-)
 rename mlir/test/Dialect/GPU/{all-reduce.mlir => all-reduce-add.mlir} (100%)
 rename mlir/test/Dialect/GPU/{all-reduce-max.mlir => all-reduce-maxf.mlir} (69%)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e11c5c393648de7..424e30a32830890 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -868,25 +868,41 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
   }];
 }
 
-// add, mul mirror the XLA ComparisonDirection enum.
+// These mirror the reduction combining kinds from the vector dialect.
 def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
-def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
-def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
-def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
-def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
-def GPU_AllReduceOpOr  : I32EnumAttrCase<"OR",  5, "or">;
-def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
+def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
+def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
+def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
+// Follows the `arith.minnumf` semantics.
+def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
+def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
+def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
+// Follows the `arith.maxnumf` semantics.
+def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
+def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
+def GPU_AllReduceOpOr  : I32EnumAttrCase<"OR",  9, "or">;
+def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
+// Follows the `arith.minimumf` semantics.
+def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
+// Follows the `arith.maximumf` semantics.
+def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
 
 def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
     "built-in reduction operations supported by gpu.allreduce.",
     [
       GPU_AllReduceOpAdd,
-      GPU_AllReduceOpAnd,
-      GPU_AllReduceOpMax,
-      GPU_AllReduceOpMin,
       GPU_AllReduceOpMul,
+      GPU_AllReduceOpMinUI,
+      GPU_AllReduceOpMinSI,
+      GPU_AllReduceOpMinF,
+      GPU_AllReduceOpMaxUI,
+      GPU_AllReduceOpMaxSI,
+      GPU_AllReduceOpMaxF,
+      GPU_AllReduceOpAnd,
       GPU_AllReduceOpOr,
-      GPU_AllReduceOpXor
+      GPU_AllReduceOpXor,
+      GPU_AllReduceOpMinimumF,
+      GPU_AllReduceOpMaximumF
     ]>{
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";
@@ -918,8 +934,11 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 
     compute the sum of each work item's %0 value. The first version specifies
     the accumulation as operation, whereas the second version specifies the
-    accumulation as code region. The accumulation operation must be one of:
-    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
+    accumulation as code region. The reduction operation must be one of:
+    *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
+       `or`, `xor`
+    *  Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
+       `maximumf`
 
     If `uniform` flag is set either none or all work items of a workgroup
     need to execute this op in convergence.
@@ -951,7 +970,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
     ```
 
     If `uniform` flag is set either none or all work items of a subgroup
-    need to execute this op in convergence.
+    need to execute this op in convergence. The reduction operation must be one
+    of:
+    *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
+       `or`, `xor`
+    *  Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
+       `maximumf`
   }];
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
                           (`uniform` $uniform^)? attr-dict
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9579..d49f7adaee9ba74 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -65,17 +65,28 @@ convertReduxKind(gpu::AllReduceOperation mode) {
   switch (mode) {
   case gpu::AllReduceOperation::ADD:
     return NVVM::ReduxKind::ADD;
+  case gpu::AllReduceOperation::MUL:
+    return std::nullopt;
+  case gpu::AllReduceOperation::MINSI:
+    return NVVM::ReduxKind::MIN;
+  case gpu::AllReduceOperation::MINUI:
+    return std::nullopt;
+  case gpu::AllReduceOperation::MINF:
+    return NVVM::ReduxKind::MIN;
+  case gpu::AllReduceOperation::MAXSI:
+    return NVVM::ReduxKind::MAX;
+  case gpu::AllReduceOperation::MAXUI:
+    return std::nullopt;
+  case gpu::AllReduceOperation::MAXF:
+    return NVVM::ReduxKind::MAX;
   case gpu::AllReduceOperation::AND:
     return NVVM::ReduxKind::AND;
-  case gpu::AllReduceOperation::MAX:
-    return NVVM::ReduxKind::MAX;
-  case gpu::AllReduceOperation::MIN:
-    return NVVM::ReduxKind::MIN;
   case gpu::AllReduceOperation::OR:
     return NVVM::ReduxKind::OR;
   case gpu::AllReduceOperation::XOR:
     return NVVM::ReduxKind::XOR;
-  case gpu::AllReduceOperation::MUL:
+  case gpu::AllReduceOperation::MINIMUMF:
+  case gpu::AllReduceOperation::MAXIMUMF:
     return std::nullopt;
   }
   return std::nullopt;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 693cc3f6236b574..272ebb8e357b40c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -503,26 +503,52 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
     return std::nullopt;
   }
 
+  // TODO: The SPIR-V spec does not specify how -0.0 / +0.0 and NaN values are
+  // handled in *FMin/*FMax reduction ops. We should double account for this not
+  // being defined in this conversion.
+
   using ReduceType = gpu::AllReduceOperation;
-  namespace spv = spirv;
   const OpHandler handlers[] = {
       {ReduceType::ADD,
-       &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
-       &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
+       &createGroupReduceOpImpl<spirv::GroupIAddOp,
+                                spirv::GroupNonUniformIAddOp>,
+       &createGroupReduceOpImpl<spirv::GroupFAddOp,
+                                spirv::GroupNonUniformFAddOp>},
       {ReduceType::MUL,
-       &createGroupReduceOpImpl<spv::GroupIMulKHROp,
-                                spv::GroupNonUniformIMulOp>,
-       &createGroupReduceOpImpl<spv::GroupFMulKHROp,
-                                spv::GroupNonUniformFMulOp>},
-      {ReduceType::MIN,
-       &createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
-       &createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
-      {ReduceType::MAX,
-       &createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
-       &createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
-  };
-
-  for (auto &handler : handlers)
+       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
+                                spirv::GroupNonUniformIMulOp>,
+       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
+                                spirv::GroupNonUniformFMulOp>},
+      {ReduceType::MINUI,
+       &createGroupReduceOpImpl<spirv::GroupUMinOp,
+                                spirv::GroupNonUniformUMinOp>,
+       nullptr},
+      {ReduceType::MINSI,
+       &createGroupReduceOpImpl<spirv::GroupSMinOp,
+                                spirv::GroupNonUniformSMinOp>,
+       nullptr},
+      {ReduceType::MINF, nullptr,
+       &createGroupReduceOpImpl<spirv::GroupFMinOp,
+                                spirv::GroupNonUniformFMinOp>},
+      {ReduceType::MAXUI,
+       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
+                                spirv::GroupNonUniformUMaxOp>,
+       nullptr},
+      {ReduceType::MAXSI,
+       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
+                                spirv::GroupNonUniformSMaxOp>,
+       nullptr},
+      {ReduceType::MAXF, nullptr,
+       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
+                                spirv::GroupNonUniformFMaxOp>},
+      {ReduceType::MINIMUMF, nullptr,
+       &createGroupReduceOpImpl<spirv::GroupFMinOp,
+                                spirv::GroupNonUniformFMinOp>},
+      {ReduceType::MAXIMUMF, nullptr,
+       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
+                                spirv::GroupNonUniformFMaxOp>}};
+
+  for (const OpHandler &handler : handlers)
     if (handler.type == opType)
       return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9517c053c8360ef..49a8d96e83c5a0b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -27,7 +27,9 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -485,12 +487,23 @@ static LogicalResult verifyAttributions(Operation *op,
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
-static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
-                                  Type resType) {
-  return (opName != gpu::AllReduceOperation::AND &&
-          opName != gpu::AllReduceOperation::OR &&
-          opName != gpu::AllReduceOperation::XOR) ||
-         llvm::isa<IntegerType>(resType);
+static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
+                                           Type resType) {
+  using Kind = gpu::AllReduceOperation;
+  if (llvm::is_contained(
+          {Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
+    if (!isa<FloatType>(resType))
+      return failure();
+  }
+
+  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
+                          Kind::AND, Kind::OR, Kind::XOR},
+                         opName)) {
+    if (!isa<IntegerType>(resType))
+      return failure();
+  }
+
+  return success();
 }
 
 LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -517,12 +530,13 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
       return emitError("expected gpu.yield op in region");
   } else {
     gpu::AllReduceOperation opName = *getOp();
-    if (!verifyReduceOpAndType(opName, getType())) {
-      return emitError()
-             << '`' << gpu::stringifyAllReduceOperation(opName)
-             << "` accumulator is only compatible with Integer type";
+    if (failed(verifyReduceOpAndType(opName, getType()))) {
+      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
+                         << "` reduction operation is not compatible with type "
+                         << getType();
     }
   }
+
   return success();
 }
 
@@ -573,9 +587,10 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
 
 LogicalResult gpu::SubgroupReduceOp::verify() {
   gpu::AllReduceOperation opName = getOp();
-  if (!verifyReduceOpAndType(opName, getType())) {
+  if (failed(verifyReduceOpAndType(opName, getType()))) {
     return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
-                       << "` accumulator is only compatible with Integer type";
+                       << "` reduction operation is not compatible with type "
+                       << getType();
   }
   return success();
 }
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index acf4f6d0e3d6979..ecee9a7b45e32bd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,32 +214,37 @@ struct GpuAllReduceRewriter {
 
   /// Returns an accumulator factory that creates an op specified by opName.
   AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
+    using Kind = gpu::AllReduceOperation;
     bool isFloatingPoint = isa<FloatType>(valueType);
     switch (opName) {
-    case gpu::AllReduceOperation::ADD:
+    case Kind::ADD:
       return isFloatingPoint ? getFactory<arith::AddFOp>()
                              : getFactory<arith::AddIOp>();
-    case gpu::AllReduceOperation::MUL:
+    case Kind::MUL:
       return isFloatingPoint ? getFactory<arith::MulFOp>()
                              : getFactory<arith::MulIOp>();
-    case gpu::AllReduceOperation::AND:
+    case Kind::MINSI:
+      return getFactory<arith::MinSIOp>();
+    case Kind::MINUI:
+      return getFactory<arith::MinUIOp>();
+    case Kind::MINF:
+      return getFactory<arith::MinNumFOp>();
+    case Kind::MAXSI:
+      return getFactory<arith::MaxSIOp>();
+    case Kind::MAXUI:
+      return getFactory<arith::MaxUIOp>();
+    case Kind::MAXF:
+      return getFactory<arith::MaxNumFOp>();
+    case Kind::AND:
       return getFactory<arith::AndIOp>();
-    case gpu::AllReduceOperation::OR:
+    case Kind::OR:
       return getFactory<arith::OrIOp>();
-    case gpu::AllReduceOperation::XOR:
+    case Kind::XOR:
       return getFactory<arith::XOrIOp>();
-    case gpu::AllReduceOperation::MAX:
-      return isFloatingPoint
-                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
-                                 arith::CmpFPredicate::UGT>()
-                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
-                                 arith::CmpIPredicate::ugt>();
-    case gpu::AllReduceOperation::MIN:
-      return isFloatingPoint
-                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
-                                 arith::CmpFPredicate::ULT>()
-                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
-                                 arith::CmpIPredicate::ult>();
+    case Kind::MINIMUMF:
+      return getFactory<arith::MinimumFOp>();
+    case Kind::MAXIMUMF:
+      return getFactory<arith::MaximumFOp>();
     }
     llvm_unreachable("unknown GPU AllReduceOperation");
   }
@@ -247,21 +252,11 @@ struct GpuAllReduceRewriter {
   /// Returns an accumulator factory that creates an op of type T.
   template <typename T>
   AccumulatorFactory getFactory() {
-    return [&](Value lhs, Value rhs) {
+    return [this](Value lhs, Value rhs) {
       return create<T>(lhs.getType(), lhs, rhs);
     };
   }
 
-  /// Returns an accumulator for comparison such as min, max. T is the type
-  /// of the compare op.
-  template <typename T, typename PredicateEnum, PredicateEnum predicate>
-  AccumulatorFactory getCmpFactory() const {
-    return [&](Value lhs, Value rhs) {
-      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
-      return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
-    };
-  }
-
   /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
   ///
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c18bb423a6e6001..20a200e812c1259 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -582,22 +582,22 @@ gpu.module @test_module_30 {
     %result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL: func @subgroup_reduce_and
-  gpu.func @subgroup_reduce_and(%arg0 : i32) {
-    // CHECK: nvvm.redux.sync and {{.*}}
-    %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
+  // CHECK-LABEL: @subgroup_reduce_minsi
+  gpu.func @subgroup_reduce_minsi(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync min {{.*}}
+    %result = gpu.subgroup_reduce minsi %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL:  @subgroup_reduce_max
-  gpu.func @subgroup_reduce_max(%arg0 : i32) {
+  // CHECK-LABEL:  @subgroup_reduce_maxsi
+  gpu.func @subgroup_reduce_maxsi(%arg0 : i32) {
     // CHECK: nvvm.redux.sync max {{.*}}
-    %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
+    %result = gpu.subgroup_reduce maxsi %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL: @subgroup_reduce_min
-  gpu.func @subgroup_reduce_min(%arg0 : i32) {
-    // CHECK: nvvm.redux.sync min {{.*}}
-    %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
+  // CHECK-LABEL: func @subgroup_reduce_and
+  gpu.func @subgroup_reduce_and(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync and {{.*}}
+    %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
   // CHECK-LABEL:  @subgroup_reduce_or
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index 1e5d64387650ce3..feb7ee185a8a858 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -331,7 +331,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMin <Workgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.all_reduce min %arg uniform {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce minf %arg uniform {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -351,7 +351,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Workgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.all_reduce min %arg {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce minf %arg {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -371,7 +371,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMin <Workgroup> <Reduce> %[[ARG]] : i32
-    %reduced = gpu.all_reduce min %arg uniform {} : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupUMin <Workgroup> <Reduce> %[[ARG]] : i32
+    %r0 = gpu.all_reduce minsi %arg uniform {} : (i32) -> (i32)
+    %r1 = gpu.all_reduce minui %arg uniform {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -390,8 +392,9 @@ gpu.module @kernels {
   //  CHECK-SAME: (%[[ARG:.*]]: i32)
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
-    // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Workgroup" "Reduce" %[[ARG]] : i32
-    %reduced = gpu.all_reduce min %arg {} : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Workgroup" "Reduce" %[[ARG]] : i32
+    %r0 = gpu.all_reduce minsi %arg {} : (i32) -> (i32)
+    %r1 = gpu.all_reduce minui %arg {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -411,7 +414,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMin <Subgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce min %arg uniform : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce minf %arg uniform : (f32) -> (f32)
     gpu.return
   }
 }
@@ -431,7 +434,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Subgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce min %arg : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce minf %arg : (f32) -> (f32)
     gpu.return
   }
 }
@@ -451,7 +454,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMin <Subgroup> <Reduce> %[[ARG]] : i32
-    %reduced = gpu.subgroup_reduce min %arg uniform : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupUMin <Subgroup> <Reduce> %[[ARG]] : i32
+    %r0 = gpu.subgroup_reduce minsi %arg uniform : (i32) -> (i32)
+    %r1 = gpu.subgroup_reduce minui %arg uniform : (i32) -> (i32)
     gpu.return
   }
 }
@@ -471,7 +476,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Subgroup" "Reduce" %[[ARG]] : i32
-    %reduced = gpu.subgroup_reduce min %arg : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Subgroup" "Reduce" %[[ARG]] : i32
+    %r0 = gpu.subgroup_reduce minsi %arg : (i32) -> (i32)
+    %r1 = gpu.subgroup_reduce minui %arg : (i32) -> (i32)
     gpu.return
   }
 }
@@ -491,7 +498,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMax <Workgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.all_reduce max %arg uniform {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce maxf %arg uniform {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -511,7 +518,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Workgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.all_reduce max %arg {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce maxf %arg {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -531,7 +538,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMax <Workgroup> <Reduce> %[[ARG]] : i32
-    %reduced = gpu.all_reduce max %arg uniform {} : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupUMax <Workgroup> <Reduce> %[[ARG]] : i32
+    %r0 = gpu.all_reduce maxsi %arg uniform {} : (i32) -> (i32)
+    %r1 = gpu.all_reduce maxui %arg uniform {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -551,7 +560,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Workgroup" "Reduce" %[[ARG]] : i32
-    %reduced = gpu.all_reduce max %arg {} : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Workgroup" "Reduce" %[[ARG]] : i32
+    %r0 = gpu.all_reduce maxsi %arg {} : (i32) -> (i32)
+    %r1 = gpu.all_reduce maxui %arg {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -571,7 +582,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMax <Subgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce max %arg uniform : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce maxf %arg uniform : (f32) -> (f32)
     gpu.return
   }
 }
@@ -591,7 +602,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Subgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce max %arg : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce maxf %arg : (f32) -> (f32)
     gpu.return
   }
 }
@@ -611,7 +622,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMax <Subgroup> <Reduce> %[[ARG]] : i32
-    %reduced = gpu.subgroup_reduce max %arg uniform : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupUMax <Subgroup> <Reduce> %[[ARG]] : i32
+    %r0 = gpu.subgroup_reduce maxsi %arg uniform : (i32) -> (i32)
+    %r1 = gpu.subgroup_reduce maxui %arg uniform : (i32) -> (i32)
     gpu.return
   }
 }
@@ -631,7 +644,9 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
-    %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
+    %r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
+    %r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce-add.mlir
similarity index 100%
rename from mlir/test/Dialect/GPU/all-reduce.mlir
rename to mlir/test/Dialect/GPU/all-reduce-add.mlir
diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-maxf.mlir
similarity index 69%
rename from mlir/test/Dialect/GPU/all-reduce-max.mlir
rename to mlir/test/Dialect/GPU/all-reduce-maxf.mlir
index a71544ba0e98d36..b502e587637cdc8 100644
--- a/mlir/test/Dialect/GPU/all-reduce-max.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-maxf.mlir
@@ -44,65 +44,55 @@ gpu.module @kernels {
     // CHECK:   [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_35]], ^bb2, ^bb3
     // CHECK: ^bb2:
-    // CHECK:   [[VAL_36:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_34]] : f32
-    // CHECK:   [[VAL_37:%.*]] = arith.select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32
-    // CHECK:   cf.br ^bb4([[VAL_37]] : f32)
+    // CHECK:   [[VAL_36:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_34]] : f32
+    // CHECK:   cf.br ^bb4([[VAL_36]] : f32)
     // CHECK: ^bb3:
     // CHECK:   cf.br ^bb4([[VAL_0]] : f32)
     // CHECK: ^bb4([[VAL_38:%.*]]: f32):
     // CHECK:   [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle xor [[VAL_38]], [[VAL_7]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_40]], ^bb5, ^bb6
     // CHECK: ^bb5:
-    // CHECK:   [[VAL_41:%.*]] = arith.cmpf ugt, [[VAL_38]], [[VAL_39]] : f32
-    // CHECK:   [[VAL_42:%.*]] = arith.select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32
-    // CHECK:   cf.br ^bb7([[VAL_42]] : f32)
+    // CHECK:   [[VAL_41:%.*]] = arith.maxnumf [[VAL_38]], [[VAL_39]] : f32
+    // CHECK:   cf.br ^bb7([[VAL_41]] : f32)
     // CHECK: ^bb6:
     // CHECK:   cf.br ^bb7([[VAL_38]] : f32)
     // CHECK: ^bb7([[VAL_43:%.*]]: f32):
     // CHECK:   [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle xor [[VAL_43]], [[VAL_8]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_45]], ^bb8, ^bb9
     // CHECK: ^bb8:
-    // CHECK:   [[VAL_46:%.*]] = arith.cmpf ugt, [[VAL_43]], [[VAL_44]] : f32
-    // CHECK:   [[VAL_47:%.*]] = arith.select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32
-    // CHECK:   cf.br ^bb10([[VAL_47]] : f32)
+    // CHECK:   [[VAL_46:%.*]] = arith.maxnumf [[VAL_43]], [[VAL_44]] : f32
+    // CHECK:   cf.br ^bb10([[VAL_46]] : f32)
     // CHECK: ^bb9:
     // CHECK:   cf.br ^bb10([[VAL_43]] : f32)
     // CHECK: ^bb10([[VAL_48:%.*]]: f32):
     // CHECK:   [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle xor [[VAL_48]], [[VAL_9]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_50]], ^bb11, ^bb12
     // CHECK: ^bb11:
-    // CHECK:   [[VAL_51:%.*]] = arith.cmpf ugt, [[VAL_48]], [[VAL_49]] : f32
-    // CHECK:   [[VAL_52:%.*]] = arith.select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32
-    // CHECK:   cf.br ^bb13([[VAL_52]] : f32)
+    // CHECK:   [[VAL_51:%.*]] = arith.maxnumf [[VAL_48]], [[VAL_49]] : f32
+    // CHECK:   cf.br ^bb13([[VAL_51]] : f32)
     // CHECK: ^bb12:
     // CHECK:   cf.br ^bb13([[VAL_48]] : f32)
     // CHECK: ^bb13([[VAL_53:%.*]]: f32):
     // CHECK:   [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle xor [[VAL_53]], [[VAL_10]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_55]], ^bb14, ^bb15
     // CHECK: ^bb14:
-    // CHECK:   [[VAL_56:%.*]] = arith.cmpf ugt, [[VAL_53]], [[VAL_54]] : f32
-    // CHECK:   [[VAL_57:%.*]] = arith.select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32
-    // CHECK:   cf.br ^bb16([[VAL_57]] : f32)
+    // CHECK:   [[VAL_56:%.*]] = arith.maxnumf [[VAL_53]], [[VAL_54]] : f32
+    // CHECK:   cf.br ^bb16([[VAL_56]] : f32)
     // CHECK: ^bb15:
     // CHECK:   cf.br ^bb16([[VAL_53]] : f32)
     // CHECK: ^bb16([[VAL_58:%.*]]: f32):
     // CHECK:   cf.br ^bb18([[VAL_58]] : f32)
     // CHECK: ^bb17:
     // CHECK:   [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_61:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_59]] : f32
-    // CHECK:   [[VAL_62:%.*]] = arith.select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32
+    // CHECK:   [[VAL_62:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_59]] : f32
     // CHECK:   [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle xor [[VAL_62]], [[VAL_7]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_65:%.*]] = arith.cmpf ugt, [[VAL_62]], [[VAL_63]] : f32
-    // CHECK:   [[VAL_66:%.*]] = arith.select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32
+    // CHECK:   [[VAL_66:%.*]] = arith.maxnumf [[VAL_62]], [[VAL_63]] : f32
     // CHECK:   [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle xor [[VAL_66]], [[VAL_8]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_69:%.*]] = arith.cmpf ugt, [[VAL_66]], [[VAL_67]] : f32
-    // CHECK:   [[VAL_70:%.*]] = arith.select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32
+    // CHECK:   [[VAL_70:%.*]] = arith.maxnumf [[VAL_66]], [[VAL_67]] : f32
     // CHECK:   [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle xor [[VAL_70]], [[VAL_9]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_73:%.*]] = arith.cmpf ugt, [[VAL_70]], [[VAL_71]] : f32
-    // CHECK:   [[VAL_74:%.*]] = arith.select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32
+    // CHECK:   [[VAL_74:%.*]] = arith.maxnumf [[VAL_70]], [[VAL_71]] : f32
     // CHECK:   [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle xor [[VAL_74]], [[VAL_10]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_77:%.*]] = arith.cmpf ugt, [[VAL_74]], [[VAL_75]] : f32
-    // CHECK:   [[VAL_78:%.*]] = arith.select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32
+    // CHECK:   [[VAL_78:%.*]] = arith.maxnumf [[VAL_74]], [[VAL_75]] : f32
     // CHECK:   cf.br ^bb18([[VAL_78]] : f32)
     // CHECK: ^bb18([[VAL_79:%.*]]: f32):
     // CHECK:   cf.cond_br [[VAL_30]], ^bb19, ^bb20
@@ -128,8 +118,7 @@ gpu.module @kernels {
     // CHECK:   [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_89]], ^bb24, ^bb25
     // CHECK: ^bb24:
-    // CHECK:   [[VAL_90:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_88]] : f32
-    // CHECK:   [[VAL_91:%.*]] = arith.select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32
+    // CHECK:   [[VAL_91:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_88]] : f32
     // CHECK:   cf.br ^bb26([[VAL_91]] : f32)
     // CHECK: ^bb25:
     // CHECK:   cf.br ^bb26([[VAL_86]] : f32)
@@ -137,8 +126,7 @@ gpu.module @kernels {
     // CHECK:   [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle xor [[VAL_92]], [[VAL_7]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_94]], ^bb27, ^bb28
     // CHECK: ^bb27:
-    // CHECK:   [[VAL_95:%.*]] = arith.cmpf ugt, [[VAL_92]], [[VAL_93]] : f32
-    // CHECK:   [[VAL_96:%.*]] = arith.select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32
+    // CHECK:   [[VAL_96:%.*]] = arith.maxnumf [[VAL_92]], [[VAL_93]] : f32
     // CHECK:   cf.br ^bb29([[VAL_96]] : f32)
     // CHECK: ^bb28:
     // CHECK:   cf.br ^bb29([[VAL_92]] : f32)
@@ -146,8 +134,7 @@ gpu.module @kernels {
     // CHECK:   [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle xor [[VAL_97]], [[VAL_8]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_99]], ^bb30, ^bb31
     // CHECK: ^bb30:
-    // CHECK:   [[VAL_100:%.*]] = arith.cmpf ugt, [[VAL_97]], [[VAL_98]] : f32
-    // CHECK:   [[VAL_101:%.*]] = arith.select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32
+    // CHECK:   [[VAL_101:%.*]] = arith.maxnumf [[VAL_97]], [[VAL_98]] : f32
     // CHECK:   cf.br ^bb32([[VAL_101]] : f32)
     // CHECK: ^bb31:
     // CHECK:   cf.br ^bb32([[VAL_97]] : f32)
@@ -155,8 +142,7 @@ gpu.module @kernels {
     // CHECK:   [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle xor [[VAL_102]], [[VAL_9]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_104]], ^bb33, ^bb34
     // CHECK: ^bb33:
-    // CHECK:   [[VAL_105:%.*]] = arith.cmpf ugt, [[VAL_102]], [[VAL_103]] : f32
-    // CHECK:   [[VAL_106:%.*]] = arith.select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32
+    // CHECK:   [[VAL_106:%.*]] = arith.maxnumf [[VAL_102]], [[VAL_103]] : f32
     // CHECK:   cf.br ^bb35([[VAL_106]] : f32)
     // CHECK: ^bb34:
     // CHECK:   cf.br ^bb35([[VAL_102]] : f32)
@@ -164,8 +150,7 @@ gpu.module @kernels {
     // CHECK:   [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle xor [[VAL_107]], [[VAL_10]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_109]], ^bb36, ^bb37
     // CHECK: ^bb36:
-    // CHECK:   [[VAL_110:%.*]] = arith.cmpf ugt, [[VAL_107]], [[VAL_108]] : f32
-    // CHECK:   [[VAL_111:%.*]] = arith.select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32
+    // CHECK:   [[VAL_111:%.*]] = arith.maxnumf [[VAL_107]], [[VAL_108]] : f32
     // CHECK:   cf.br ^bb38([[VAL_111]] : f32)
     // CHECK: ^bb37:
     // CHECK:   cf.br ^bb38([[VAL_107]] : f32)
@@ -173,20 +158,15 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb40([[VAL_112]] : f32)
     // CHECK: ^bb39:
     // CHECK:   [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_115:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_113]] : f32
-    // CHECK:   [[VAL_116:%.*]] = arith.select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32
+    // CHECK:   [[VAL_116:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_113]] : f32
     // CHECK:   [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle xor [[VAL_116]], [[VAL_7]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_119:%.*]] = arith.cmpf ugt, [[VAL_116]], [[VAL_117]] : f32
-    // CHECK:   [[VAL_120:%.*]] = arith.select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32
+    // CHECK:   [[VAL_120:%.*]] = arith.maxnumf [[VAL_116]], [[VAL_117]] : f32
     // CHECK:   [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle xor [[VAL_120]], [[VAL_8]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_123:%.*]] = arith.cmpf ugt, [[VAL_120]], [[VAL_121]] : f32
-    // CHECK:   [[VAL_124:%.*]] = arith.select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32
+    // CHECK:   [[VAL_124:%.*]] = arith.maxnumf [[VAL_120]], [[VAL_121]] : f32
     // CHECK:   [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle xor [[VAL_124]], [[VAL_9]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_127:%.*]] = arith.cmpf ugt, [[VAL_124]], [[VAL_125]] : f32
-    // CHECK:   [[VAL_128:%.*]] = arith.select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32
+    // CHECK:   [[VAL_128:%.*]] = arith.maxnumf [[VAL_124]], [[VAL_125]] : f32
     // CHECK:   [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle xor [[VAL_128]], [[VAL_10]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_131:%.*]] = arith.cmpf ugt, [[VAL_128]], [[VAL_129]] : f32
-    // CHECK:   [[VAL_132:%.*]] = arith.select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
+    // CHECK:   [[VAL_132:%.*]] = arith.maxnumf [[VAL_128]], [[VAL_129]] : f32
     // CHECK:   cf.br ^bb40([[VAL_132]] : f32)
     // CHECK: ^bb40([[VAL_133:%.*]]: f32):
     // CHECK:   store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
@@ -195,7 +175,7 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb42
     // CHECK: ^bb42:
     // CHECK:   gpu.barrier
-    %sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
+    %sum = gpu.all_reduce maxf %arg0 uniform {} : (f32) -> (f32)
     gpu.return
   }
 
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index df9921ef14d3b51..c8f5c4b40e85124 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -237,22 +237,110 @@ func.func @reduce_invalid_op(%arg0 : f32) {
 
 // -----
 
-func.func @reduce_invalid_op_type(%arg0 : f32) {
-  // expected-error at +1 {{`and` accumulator is only compatible with Integer type}}
+func.func @reduce_invalid_op_type_minsi(%arg0 : f32) {
+  // expected-error at +1 {{`minsi` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce minsi %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_minui(%arg0 : f32) {
+  // expected-error at +1 {{`minui` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce minui %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_maxsi(%arg0 : f32) {
+  // expected-error at +1 {{`maxsi` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce maxsi %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_maxui(%arg0 : f32) {
+  // expected-error at +1 {{`maxui` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce maxui %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_and(%arg0 : f32) {
+  // expected-error at +1 {{`and` reduction operation is not compatible with type 'f32'}}
   %res = gpu.all_reduce and %arg0 {} : (f32) -> (f32)
   return
 }
 
 // -----
 
-func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
-  // expected-error at +1 {{`and` accumulator is only compatible with Integer type}}
+func.func @reduce_invalid_op_type_or(%arg0 : f32) {
+  // expected-error at +1 {{`or` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce or %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_xor(%arg0 : f32) {
+  // expected-error at +1 {{`xor` reduction operation is not compatible with type 'f32'}}
+  %res = gpu.all_reduce xor %arg0 {} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_minf(%arg0 : i32) {
+  // expected-error at +1 {{`minf` reduction operation is not compatible with type 'i32'}}
+  %res = gpu.all_reduce minf %arg0 {} : (i32) -> (i32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_maxf(%arg0 : i32) {
+  // expected-error at +1 {{`maxf` reduction operation is not compatible with type 'i32'}}
+  %res = gpu.all_reduce maxf %arg0 {} : (i32) -> (i32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_minimumf(%arg0 : i32) {
+  // expected-error at +1 {{`minimumf` reduction operation is not compatible with type 'i32'}}
+  %res = gpu.all_reduce minimumf %arg0 {} : (i32) -> (i32)
+  return
+}
+
+// -----
+
+func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
+  // expected-error at +1 {{`maximumf` reduction operation is not compatible with type 'i32'}}
+  %res = gpu.all_reduce maximumf %arg0 {} : (i32) -> (i32)
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_invalid_op_type_and(%arg0 : f32) {
+  // expected-error at +1 {{`and` reduction operation is not compatible with type 'f32'}}
   %res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
   return
 }
 
 // -----
 
+func.func @subgroup_reduce_invalid_op_type_maxf(%arg0 : i32) {
+  // expected-error at +1 {{`maxf` reduction operation is not compatible with type 'i32'}}
+  %res = gpu.subgroup_reduce maxf %arg0 : (i32) -> (i32)
+  return
+}
+
+// -----
+
 func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error at +1 {{expected two region arguments}}
   %res = gpu.all_reduce %arg0 {
@@ -647,11 +735,11 @@ func.func @main() {
   %shmemSize = arith.constant 10000 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
-    %0 = gpu.dynamic_shared_memory : memref<?xi8>  
+    %0 = gpu.dynamic_shared_memory : memref<?xi8>
     gpu.terminator
   }
   return
@@ -664,11 +752,11 @@ func.func @main() {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
-    %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>  
+    %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
     gpu.terminator
   }
   return
@@ -680,7 +768,7 @@ func.func @main(%arg0 : index) {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
@@ -696,7 +784,7 @@ func.func @main(%arg0 : index) {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}

From a820424051f1dc18d529e7315a69286f078e0ce0 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 26 Nov 2023 16:05:51 -0500
Subject: [PATCH 2/2] Point to tracking issue for fp precision in spirv
 conversions

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 272ebb8e357b40c..6536bbe1f4dba47 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -503,9 +503,10 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
     return std::nullopt;
   }
 
-  // TODO: The SPIR-V spec does not specify how -0.0 / +0.0 and NaN values are
-  // handled in *FMin/*FMax reduction ops. We should double account for this not
-  // being defined in this conversion.
+  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
+  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
+  // reduction ops. We should account for possible precision requirements in
+  // this conversion.
 
   using ReduceType = gpu::AllReduceOperation;
   const OpHandler handlers[] = {


