[Mlir-commits] [mlir] dd09221 - Revert "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 27 08:30:12 PST 2023


Author: Jakub Kuderski
Date: 2023-11-27T11:29:23-05:00
New Revision: dd09221a29506031415cad8a1308998358633d48

URL: https://github.com/llvm/llvm-project/commit/dd09221a29506031415cad8a1308998358633d48
DIFF: https://github.com/llvm/llvm-project/commit/dd09221a29506031415cad8a1308998358633d48.diff

LOG: Revert "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"

This reverts commit e0aac8c88d0d30e8da0f8a240ad1e6b4d88782e0.

I'm seeing some NVIDIA integration test failures:
https://lab.llvm.org/buildbot/#/builders/61/builds/52334.

Added: 
    mlir/test/Dialect/GPU/all-reduce-max.mlir
    mlir/test/Dialect/GPU/all-reduce.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Conversion/GPUToSPIRV/reductions.mlir
    mlir/test/Dialect/GPU/invalid.mlir

Removed: 
    mlir/test/Dialect/GPU/all-reduce-add.mlir
    mlir/test/Dialect/GPU/all-reduce-maxf.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 7cad1cd89fd6335..826df0012fb8f0a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -931,53 +931,38 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
   }];
 }
 
-// These mirror the reduction combining kinds from the vector dialect.
+// add, mul mirror the XLA ComparisonDirection enum.
 def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
-def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
-def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
-def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
-// Follows the `arith.minnumf` semantics.
-def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
-def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
-def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
-// Follows the `arith.maxnumf` semantics.
-def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
-def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
-def GPU_AllReduceOpOr  : I32EnumAttrCase<"OR",  9, "or">;
-def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
-// Follows the `arith.minimumf` semantics.
-def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
-// Follows the `arith.maximumf` semantics.
-def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
+def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
+def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
+def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
+def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
+def GPU_AllReduceOpOr  : I32EnumAttrCase<"OR",  5, "or">;
+def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
 
 def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
     "built-in reduction operations supported by gpu.allreduce.",
     [
       GPU_AllReduceOpAdd,
-      GPU_AllReduceOpMul,
-      GPU_AllReduceOpMinUI,
-      GPU_AllReduceOpMinSI,
-      GPU_AllReduceOpMinF,
-      GPU_AllReduceOpMaxUI,
-      GPU_AllReduceOpMaxSI,
-      GPU_AllReduceOpMaxF,
       GPU_AllReduceOpAnd,
+      GPU_AllReduceOpMax,
+      GPU_AllReduceOpMin,
+      GPU_AllReduceOpMul,
       GPU_AllReduceOpOr,
-      GPU_AllReduceOpXor,
-      GPU_AllReduceOpMinimumF,
-      GPU_AllReduceOpMaximumF
+      GPU_AllReduceOpXor
     ]>{
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";
 }
-
-def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
-
 def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
                                           "all_reduce_op">;
 
 def GPU_AllReduceOp : GPU_Op<"all_reduce",
-    [SameOperandsAndResultType, IsolatedFromAbove]> {
+    [SameOperandsAndResultType, IsolatedFromAbove]>,
+    Arguments<(ins AnyType:$value,
+               OptionalAttr<GPU_AllReduceOperationAttr>:$op,
+               UnitAttr:$uniform)>,
+    Results<(outs AnyType)> {
   let summary = "Reduce values among workgroup.";
   let description = [{
     The `all_reduce` op reduces the value of every work item across a local
@@ -996,23 +981,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 
     compute the sum of each work item's %0 value. The first version specifies
     the accumulation as operation, whereas the second version specifies the
-    accumulation as code region. The reduction operation must be one of:
-    *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
-       `or`, `xor`
-    *  Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
-       `maximumf`
+    accumulation as code region. The accumulation operation must be one of:
+    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
 
     If `uniform` flag is set either none or all work items of a workgroup
     need to execute this op in convergence.
   }];
-
-  let arguments = (ins
-    AnyIntegerOrFloat:$value,
-    OptionalAttr<GPU_AllReduceOperationAttr>:$op,
-    UnitAttr:$uniform
-  );
-  let results = (outs AnyIntegerOrFloat:$result);
-
   let regions = (region AnyRegion:$body);
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
                           (`uniform` $uniform^)? $body attr-dict
@@ -1022,7 +996,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let hasRegionVerifier = 1;
 }
 
-def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
+def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
+    [SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$value,
+               GPU_AllReduceOperationAttr:$op,
+               UnitAttr:$uniform)>,
+    Results<(outs AnyType)> {
   let summary = "Reduce values among subgroup.";
   let description = [{
     The `subgroup_reduce` op reduces the value of every work item across a
@@ -1035,21 +1014,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     ```
 
     If `uniform` flag is set either none or all work items of a subgroup
-    need to execute this op in convergence. The reduction operation must be one
-    of:
-    *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
-       `or`, `xor`
-    *  Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
-       `maximumf`
+    need to execute this op in convergence.
   }];
-
-  let arguments = (ins
-    AnyIntegerOrFloat:$value,
-    GPU_AllReduceOperationAttr:$op,
-    UnitAttr:$uniform
-  );
-  let results = (outs AnyIntegerOrFloat:$result);
-
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
                           (`uniform` $uniform^)? attr-dict
                           `:` functional-type(operands, results) }];

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 03180a687523bf7..b0b5348baaad963 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                   !::llvm::cast<VectorType>($_self).isScalable()}]>;
 
 // Whether a type is a scalable VectorType.
-def IsVectorTypeWithAnyDimScalablePred
+def IsVectorTypeWithAnyDimScalablePred 
         : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                   ::llvm::cast<VectorType>($_self).isScalable()}]>;
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4855fd187eb5861..9456784c406aebb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -65,28 +65,17 @@ convertReduxKind(gpu::AllReduceOperation mode) {
   switch (mode) {
   case gpu::AllReduceOperation::ADD:
     return NVVM::ReduxKind::ADD;
-  case gpu::AllReduceOperation::MUL:
-    return std::nullopt;
-  case gpu::AllReduceOperation::MINSI:
-    return NVVM::ReduxKind::MIN;
-  case gpu::AllReduceOperation::MINUI:
-    return std::nullopt;
-  case gpu::AllReduceOperation::MINF:
-    return NVVM::ReduxKind::MIN;
-  case gpu::AllReduceOperation::MAXSI:
-    return NVVM::ReduxKind::MAX;
-  case gpu::AllReduceOperation::MAXUI:
-    return std::nullopt;
-  case gpu::AllReduceOperation::MAXF:
-    return NVVM::ReduxKind::MAX;
   case gpu::AllReduceOperation::AND:
     return NVVM::ReduxKind::AND;
+  case gpu::AllReduceOperation::MAX:
+    return NVVM::ReduxKind::MAX;
+  case gpu::AllReduceOperation::MIN:
+    return NVVM::ReduxKind::MIN;
   case gpu::AllReduceOperation::OR:
     return NVVM::ReduxKind::OR;
   case gpu::AllReduceOperation::XOR:
     return NVVM::ReduxKind::XOR;
-  case gpu::AllReduceOperation::MINIMUMF:
-  case gpu::AllReduceOperation::MAXIMUMF:
+  case gpu::AllReduceOperation::MUL:
     return std::nullopt;
   }
   return std::nullopt;

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 6536bbe1f4dba47..693cc3f6236b574 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -503,53 +503,26 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
     return std::nullopt;
   }
 
-  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
-  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
-  // reduction ops. We should account possible precision requirements in this
-  // conversion.
-
   using ReduceType = gpu::AllReduceOperation;
+  namespace spv = spirv;
   const OpHandler handlers[] = {
       {ReduceType::ADD,
-       &createGroupReduceOpImpl<spirv::GroupIAddOp,
-                                spirv::GroupNonUniformIAddOp>,
-       &createGroupReduceOpImpl<spirv::GroupFAddOp,
-                                spirv::GroupNonUniformFAddOp>},
+       &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
+       &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
       {ReduceType::MUL,
-       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
-                                spirv::GroupNonUniformIMulOp>,
-       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
-                                spirv::GroupNonUniformFMulOp>},
-      {ReduceType::MINUI,
-       &createGroupReduceOpImpl<spirv::GroupUMinOp,
-                                spirv::GroupNonUniformUMinOp>,
-       nullptr},
-      {ReduceType::MINSI,
-       &createGroupReduceOpImpl<spirv::GroupSMinOp,
-                                spirv::GroupNonUniformSMinOp>,
-       nullptr},
-      {ReduceType::MINF, nullptr,
-       &createGroupReduceOpImpl<spirv::GroupFMinOp,
-                                spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXUI,
-       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
-                                spirv::GroupNonUniformUMaxOp>,
-       nullptr},
-      {ReduceType::MAXSI,
-       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
-                                spirv::GroupNonUniformSMaxOp>,
-       nullptr},
-      {ReduceType::MAXF, nullptr,
-       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
-                                spirv::GroupNonUniformFMaxOp>},
-      {ReduceType::MINIMUMF, nullptr,
-       &createGroupReduceOpImpl<spirv::GroupFMinOp,
-                                spirv::GroupNonUniformFMinOp>},
-      {ReduceType::MAXIMUMF, nullptr,
-       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
-                                spirv::GroupNonUniformFMaxOp>}};
-
-  for (const OpHandler &handler : handlers)
+       &createGroupReduceOpImpl<spv::GroupIMulKHROp,
+                                spv::GroupNonUniformIMulOp>,
+       &createGroupReduceOpImpl<spv::GroupFMulKHROp,
+                                spv::GroupNonUniformFMulOp>},
+      {ReduceType::MIN,
+       &createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
+       &createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
+      {ReduceType::MAX,
+       &createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
+       &createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
+  };
+
+  for (auto &handler : handlers)
     if (handler.type == opType)
       return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
 

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d31903ea201158f..1b6db1fb0c79f7c 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -27,9 +27,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -488,23 +486,12 @@ static LogicalResult verifyAttributions(Operation *op,
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
-                                           Type resType) {
-  using Kind = gpu::AllReduceOperation;
-  if (llvm::is_contained(
-          {Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
-    if (!isa<FloatType>(resType))
-      return failure();
-  }
-
-  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
-                          Kind::AND, Kind::OR, Kind::XOR},
-                         opName)) {
-    if (!isa<IntegerType>(resType))
-      return failure();
-  }
-
-  return success();
+static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
+                                  Type resType) {
+  return (opName != gpu::AllReduceOperation::AND &&
+          opName != gpu::AllReduceOperation::OR &&
+          opName != gpu::AllReduceOperation::XOR) ||
+         llvm::isa<IntegerType>(resType);
 }
 
 LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -531,13 +518,12 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
       return emitError("expected gpu.yield op in region");
   } else {
     gpu::AllReduceOperation opName = *getOp();
-    if (failed(verifyReduceOpAndType(opName, getType()))) {
-      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
-                         << "` reduction operation is not compatible with type "
-                         << getType();
+    if (!verifyReduceOpAndType(opName, getType())) {
+      return emitError()
+             << '`' << gpu::stringifyAllReduceOperation(opName)
+             << "` accumulator is only compatible with Integer type";
     }
   }
-
   return success();
 }
 
@@ -588,10 +574,9 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
 
 LogicalResult gpu::SubgroupReduceOp::verify() {
   gpu::AllReduceOperation opName = getOp();
-  if (failed(verifyReduceOpAndType(opName, getType()))) {
+  if (!verifyReduceOpAndType(opName, getType())) {
     return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
-                       << "` reduction operation is not compatible with type "
-                       << getType();
+                       << "` accumulator is only compatible with Integer type";
   }
   return success();
 }

diff  --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32bd..acf4f6d0e3d6979 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,37 +214,32 @@ struct GpuAllReduceRewriter {
 
   /// Returns an accumulator factory that creates an op specified by opName.
   AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
-    using Kind = gpu::AllReduceOperation;
     bool isFloatingPoint = isa<FloatType>(valueType);
     switch (opName) {
-    case Kind::ADD:
+    case gpu::AllReduceOperation::ADD:
       return isFloatingPoint ? getFactory<arith::AddFOp>()
                              : getFactory<arith::AddIOp>();
-    case Kind::MUL:
+    case gpu::AllReduceOperation::MUL:
       return isFloatingPoint ? getFactory<arith::MulFOp>()
                              : getFactory<arith::MulIOp>();
-    case Kind::MINSI:
-      return getFactory<arith::MinSIOp>();
-    case Kind::MINUI:
-      return getFactory<arith::MinUIOp>();
-    case Kind::MINF:
-      return getFactory<arith::MinNumFOp>();
-    case Kind::MAXSI:
-      return getFactory<arith::MaxSIOp>();
-    case Kind::MAXUI:
-      return getFactory<arith::MaxUIOp>();
-    case Kind::MAXF:
-      return getFactory<arith::MaxNumFOp>();
-    case Kind::AND:
+    case gpu::AllReduceOperation::AND:
       return getFactory<arith::AndIOp>();
-    case Kind::OR:
+    case gpu::AllReduceOperation::OR:
       return getFactory<arith::OrIOp>();
-    case Kind::XOR:
+    case gpu::AllReduceOperation::XOR:
       return getFactory<arith::XOrIOp>();
-    case Kind::MINIMUMF:
-      return getFactory<arith::MinimumFOp>();
-    case Kind::MAXIMUMF:
-      return getFactory<arith::MaximumFOp>();
+    case gpu::AllReduceOperation::MAX:
+      return isFloatingPoint
+                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
+                                 arith::CmpFPredicate::UGT>()
+                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
+                                 arith::CmpIPredicate::ugt>();
+    case gpu::AllReduceOperation::MIN:
+      return isFloatingPoint
+                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
+                                 arith::CmpFPredicate::ULT>()
+                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
+                                 arith::CmpIPredicate::ult>();
     }
     llvm_unreachable("unknown GPU AllReduceOperation");
   }
@@ -252,11 +247,21 @@ struct GpuAllReduceRewriter {
   /// Returns an accumulator factory that creates an op of type T.
   template <typename T>
   AccumulatorFactory getFactory() {
-    return [this](Value lhs, Value rhs) {
+    return [&](Value lhs, Value rhs) {
       return create<T>(lhs.getType(), lhs, rhs);
     };
   }
 
+  /// Returns an accumulator for comparison such as min, max. T is the type
+  /// of the compare op.
+  template <typename T, typename PredicateEnum, PredicateEnum predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [&](Value lhs, Value rhs) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an if-block skeleton and calls the two factories to generate the
   /// ops in the `then` and `else` block..
   ///

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 20a200e812c1259..c18bb423a6e6001 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -582,22 +582,22 @@ gpu.module @test_module_30 {
     %result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL: @subgroup_reduce_minsi
-  gpu.func @subgroup_reduce_minsi(%arg0 : i32) {
-    // CHECK: nvvm.redux.sync min {{.*}}
-    %result = gpu.subgroup_reduce minsi %arg0 uniform {} : (i32) -> (i32)
+  // CHECK-LABEL: func @subgroup_reduce_and
+  gpu.func @subgroup_reduce_and(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync and {{.*}}
+    %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL:  @subgroup_reduce_maxsi
-  gpu.func @subgroup_reduce_maxsi(%arg0 : i32) {
+  // CHECK-LABEL:  @subgroup_reduce_max
+  gpu.func @subgroup_reduce_max(%arg0 : i32) {
     // CHECK: nvvm.redux.sync max {{.*}}
-    %result = gpu.subgroup_reduce maxsi %arg0 uniform {} : (i32) -> (i32)
+    %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
-  // CHECK-LABEL: func @subgroup_reduce_and
-  gpu.func @subgroup_reduce_and(%arg0 : i32) {
-    // CHECK: nvvm.redux.sync and {{.*}}
-    %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
+  // CHECK-LABEL: @subgroup_reduce_min
+  gpu.func @subgroup_reduce_min(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync min {{.*}}
+    %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
     gpu.return
   }
   // CHECK-LABEL:  @subgroup_reduce_or

diff  --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index feb7ee185a8a858..1e5d64387650ce3 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -331,7 +331,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMin <Workgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.all_reduce minf %arg uniform {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce min %arg uniform {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -351,7 +351,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Workgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.all_reduce minf %arg {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce min %arg {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -371,9 +371,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMin <Workgroup> <Reduce> %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupUMin <Workgroup> <Reduce> %[[ARG]] : i32
-    %r0 = gpu.all_reduce minsi %arg uniform {} : (i32) -> (i32)
-    %r1 = gpu.all_reduce minui %arg uniform {} : (i32) -> (i32)
+    %reduced = gpu.all_reduce min %arg uniform {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -392,9 +390,8 @@ gpu.module @kernels {
   //  CHECK-SAME: (%[[ARG:.*]]: i32)
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
-    // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Workgroup" "Reduce" %[[ARG]] : i32
-    %r0 = gpu.all_reduce minsi %arg {} : (i32) -> (i32)
-    %r1 = gpu.all_reduce minui %arg {} : (i32) -> (i32)
+    // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Workgroup" "Reduce" %[[ARG]] : i32
+    %reduced = gpu.all_reduce min %arg {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -414,7 +411,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMin <Subgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce minf %arg uniform : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce min %arg uniform : (f32) -> (f32)
     gpu.return
   }
 }
@@ -434,7 +431,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Subgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce minf %arg : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce min %arg : (f32) -> (f32)
     gpu.return
   }
 }
@@ -454,9 +451,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMin <Subgroup> <Reduce> %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupUMin <Subgroup> <Reduce> %[[ARG]] : i32
-    %r0 = gpu.subgroup_reduce minsi %arg uniform : (i32) -> (i32)
-    %r1 = gpu.subgroup_reduce minui %arg uniform : (i32) -> (i32)
+    %reduced = gpu.subgroup_reduce min %arg uniform : (i32) -> (i32)
     gpu.return
   }
 }
@@ -476,9 +471,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Subgroup" "Reduce" %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Subgroup" "Reduce" %[[ARG]] : i32
-    %r0 = gpu.subgroup_reduce minsi %arg : (i32) -> (i32)
-    %r1 = gpu.subgroup_reduce minui %arg : (i32) -> (i32)
+    %reduced = gpu.subgroup_reduce min %arg : (i32) -> (i32)
     gpu.return
   }
 }
@@ -498,7 +491,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMax <Workgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.all_reduce maxf %arg uniform {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce max %arg uniform {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -518,7 +511,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Workgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.all_reduce maxf %arg {} : (f32) -> (f32)
+    %reduced = gpu.all_reduce max %arg {} : (f32) -> (f32)
     gpu.return
   }
 }
@@ -538,9 +531,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMax <Workgroup> <Reduce> %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupUMax <Workgroup> <Reduce> %[[ARG]] : i32
-    %r0 = gpu.all_reduce maxsi %arg uniform {} : (i32) -> (i32)
-    %r1 = gpu.all_reduce maxui %arg uniform {} : (i32) -> (i32)
+    %reduced = gpu.all_reduce max %arg uniform {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -560,9 +551,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Workgroup" "Reduce" %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Workgroup" "Reduce" %[[ARG]] : i32
-    %r0 = gpu.all_reduce maxsi %arg {} : (i32) -> (i32)
-    %r1 = gpu.all_reduce maxui %arg {} : (i32) -> (i32)
+    %reduced = gpu.all_reduce max %arg {} : (i32) -> (i32)
     gpu.return
   }
 }
@@ -582,7 +571,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupFMax <Subgroup> <Reduce> %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce maxf %arg uniform : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce max %arg uniform : (f32) -> (f32)
     gpu.return
   }
 }
@@ -602,7 +591,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFMax "Subgroup" "Reduce" %[[ARG]] : f32
-    %reduced = gpu.subgroup_reduce maxf %arg : (f32) -> (f32)
+    %reduced = gpu.subgroup_reduce max %arg : (f32) -> (f32)
     gpu.return
   }
 }
@@ -622,9 +611,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupSMax <Subgroup> <Reduce> %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupUMax <Subgroup> <Reduce> %[[ARG]] : i32
-    %r0 = gpu.subgroup_reduce maxsi %arg uniform : (i32) -> (i32)
-    %r1 = gpu.subgroup_reduce maxui %arg uniform : (i32) -> (i32)
+    %reduced = gpu.subgroup_reduce max %arg uniform : (i32) -> (i32)
     gpu.return
   }
 }
@@ -644,9 +631,7 @@ gpu.module @kernels {
   gpu.func @test(%arg : i32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
-    // CHECK: %{{.*}} = spirv.GroupNonUniformUMax "Subgroup" "Reduce" %[[ARG]] : i32
-    %r0 = gpu.subgroup_reduce maxsi %arg : (i32) -> (i32)
-    %r1 = gpu.subgroup_reduce maxui %arg : (i32) -> (i32)
+    %reduced = gpu.subgroup_reduce max %arg : (i32) -> (i32)
     gpu.return
   }
 }

diff  --git a/mlir/test/Dialect/GPU/all-reduce-maxf.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
similarity index 69%
rename from mlir/test/Dialect/GPU/all-reduce-maxf.mlir
rename to mlir/test/Dialect/GPU/all-reduce-max.mlir
index b502e587637cdc8..a71544ba0e98d36 100644
--- a/mlir/test/Dialect/GPU/all-reduce-maxf.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -44,55 +44,65 @@ gpu.module @kernels {
     // CHECK:   [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_35]], ^bb2, ^bb3
     // CHECK: ^bb2:
-    // CHECK:   [[VAL_36:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_34]] : f32
-    // CHECK:   cf.br ^bb4([[VAL_36]] : f32)
+    // CHECK:   [[VAL_36:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_34]] : f32
+    // CHECK:   [[VAL_37:%.*]] = arith.select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32
+    // CHECK:   cf.br ^bb4([[VAL_37]] : f32)
     // CHECK: ^bb3:
     // CHECK:   cf.br ^bb4([[VAL_0]] : f32)
     // CHECK: ^bb4([[VAL_38:%.*]]: f32):
     // CHECK:   [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle xor [[VAL_38]], [[VAL_7]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_40]], ^bb5, ^bb6
     // CHECK: ^bb5:
-    // CHECK:   [[VAL_41:%.*]] = arith.maxnumf [[VAL_38]], [[VAL_39]] : f32
-    // CHECK:   cf.br ^bb7([[VAL_41]] : f32)
+    // CHECK:   [[VAL_41:%.*]] = arith.cmpf ugt, [[VAL_38]], [[VAL_39]] : f32
+    // CHECK:   [[VAL_42:%.*]] = arith.select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32
+    // CHECK:   cf.br ^bb7([[VAL_42]] : f32)
     // CHECK: ^bb6:
     // CHECK:   cf.br ^bb7([[VAL_38]] : f32)
     // CHECK: ^bb7([[VAL_43:%.*]]: f32):
     // CHECK:   [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle xor [[VAL_43]], [[VAL_8]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_45]], ^bb8, ^bb9
     // CHECK: ^bb8:
-    // CHECK:   [[VAL_46:%.*]] = arith.maxnumf [[VAL_43]], [[VAL_44]] : f32
-    // CHECK:   cf.br ^bb10([[VAL_46]] : f32)
+    // CHECK:   [[VAL_46:%.*]] = arith.cmpf ugt, [[VAL_43]], [[VAL_44]] : f32
+    // CHECK:   [[VAL_47:%.*]] = arith.select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32
+    // CHECK:   cf.br ^bb10([[VAL_47]] : f32)
     // CHECK: ^bb9:
     // CHECK:   cf.br ^bb10([[VAL_43]] : f32)
     // CHECK: ^bb10([[VAL_48:%.*]]: f32):
     // CHECK:   [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle xor [[VAL_48]], [[VAL_9]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_50]], ^bb11, ^bb12
     // CHECK: ^bb11:
-    // CHECK:   [[VAL_51:%.*]] = arith.maxnumf [[VAL_48]], [[VAL_49]] : f32
-    // CHECK:   cf.br ^bb13([[VAL_51]] : f32)
+    // CHECK:   [[VAL_51:%.*]] = arith.cmpf ugt, [[VAL_48]], [[VAL_49]] : f32
+    // CHECK:   [[VAL_52:%.*]] = arith.select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32
+    // CHECK:   cf.br ^bb13([[VAL_52]] : f32)
     // CHECK: ^bb12:
     // CHECK:   cf.br ^bb13([[VAL_48]] : f32)
     // CHECK: ^bb13([[VAL_53:%.*]]: f32):
     // CHECK:   [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle xor [[VAL_53]], [[VAL_10]], [[VAL_32]] : f32
     // CHECK:   cf.cond_br [[VAL_55]], ^bb14, ^bb15
     // CHECK: ^bb14:
-    // CHECK:   [[VAL_56:%.*]] = arith.maxnumf [[VAL_53]], [[VAL_54]] : f32
-    // CHECK:   cf.br ^bb16([[VAL_56]] : f32)
+    // CHECK:   [[VAL_56:%.*]] = arith.cmpf ugt, [[VAL_53]], [[VAL_54]] : f32
+    // CHECK:   [[VAL_57:%.*]] = arith.select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32
+    // CHECK:   cf.br ^bb16([[VAL_57]] : f32)
     // CHECK: ^bb15:
     // CHECK:   cf.br ^bb16([[VAL_53]] : f32)
     // CHECK: ^bb16([[VAL_58:%.*]]: f32):
     // CHECK:   cf.br ^bb18([[VAL_58]] : f32)
     // CHECK: ^bb17:
     // CHECK:   [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle xor [[VAL_0]], [[VAL_6]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_62:%.*]] = arith.maxnumf [[VAL_0]], [[VAL_59]] : f32
+    // CHECK:   [[VAL_61:%.*]] = arith.cmpf ugt, [[VAL_0]], [[VAL_59]] : f32
+    // CHECK:   [[VAL_62:%.*]] = arith.select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32
     // CHECK:   [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle xor [[VAL_62]], [[VAL_7]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_66:%.*]] = arith.maxnumf [[VAL_62]], [[VAL_63]] : f32
+    // CHECK:   [[VAL_65:%.*]] = arith.cmpf ugt, [[VAL_62]], [[VAL_63]] : f32
+    // CHECK:   [[VAL_66:%.*]] = arith.select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32
     // CHECK:   [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle xor [[VAL_66]], [[VAL_8]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_70:%.*]] = arith.maxnumf [[VAL_66]], [[VAL_67]] : f32
+    // CHECK:   [[VAL_69:%.*]] = arith.cmpf ugt, [[VAL_66]], [[VAL_67]] : f32
+    // CHECK:   [[VAL_70:%.*]] = arith.select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32
     // CHECK:   [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle xor [[VAL_70]], [[VAL_9]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_74:%.*]] = arith.maxnumf [[VAL_70]], [[VAL_71]] : f32
+    // CHECK:   [[VAL_73:%.*]] = arith.cmpf ugt, [[VAL_70]], [[VAL_71]] : f32
+    // CHECK:   [[VAL_74:%.*]] = arith.select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32
     // CHECK:   [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle xor [[VAL_74]], [[VAL_10]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_78:%.*]] = arith.maxnumf [[VAL_74]], [[VAL_75]] : f32
+    // CHECK:   [[VAL_77:%.*]] = arith.cmpf ugt, [[VAL_74]], [[VAL_75]] : f32
+    // CHECK:   [[VAL_78:%.*]] = arith.select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32
     // CHECK:   cf.br ^bb18([[VAL_78]] : f32)
     // CHECK: ^bb18([[VAL_79:%.*]]: f32):
     // CHECK:   cf.cond_br [[VAL_30]], ^bb19, ^bb20
@@ -118,7 +128,8 @@ gpu.module @kernels {
     // CHECK:   [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_89]], ^bb24, ^bb25
     // CHECK: ^bb24:
-    // CHECK:   [[VAL_91:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_88]] : f32
+    // CHECK:   [[VAL_90:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_88]] : f32
+    // CHECK:   [[VAL_91:%.*]] = arith.select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32
     // CHECK:   cf.br ^bb26([[VAL_91]] : f32)
     // CHECK: ^bb25:
     // CHECK:   cf.br ^bb26([[VAL_86]] : f32)
@@ -126,7 +137,8 @@ gpu.module @kernels {
     // CHECK:   [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle xor [[VAL_92]], [[VAL_7]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_94]], ^bb27, ^bb28
     // CHECK: ^bb27:
-    // CHECK:   [[VAL_96:%.*]] = arith.maxnumf [[VAL_92]], [[VAL_93]] : f32
+    // CHECK:   [[VAL_95:%.*]] = arith.cmpf ugt, [[VAL_92]], [[VAL_93]] : f32
+    // CHECK:   [[VAL_96:%.*]] = arith.select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32
     // CHECK:   cf.br ^bb29([[VAL_96]] : f32)
     // CHECK: ^bb28:
     // CHECK:   cf.br ^bb29([[VAL_92]] : f32)
@@ -134,7 +146,8 @@ gpu.module @kernels {
     // CHECK:   [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle xor [[VAL_97]], [[VAL_8]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_99]], ^bb30, ^bb31
     // CHECK: ^bb30:
-    // CHECK:   [[VAL_101:%.*]] = arith.maxnumf [[VAL_97]], [[VAL_98]] : f32
+    // CHECK:   [[VAL_100:%.*]] = arith.cmpf ugt, [[VAL_97]], [[VAL_98]] : f32
+    // CHECK:   [[VAL_101:%.*]] = arith.select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32
     // CHECK:   cf.br ^bb32([[VAL_101]] : f32)
     // CHECK: ^bb31:
     // CHECK:   cf.br ^bb32([[VAL_97]] : f32)
@@ -142,7 +155,8 @@ gpu.module @kernels {
     // CHECK:   [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle xor [[VAL_102]], [[VAL_9]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_104]], ^bb33, ^bb34
     // CHECK: ^bb33:
-    // CHECK:   [[VAL_106:%.*]] = arith.maxnumf [[VAL_102]], [[VAL_103]] : f32
+    // CHECK:   [[VAL_105:%.*]] = arith.cmpf ugt, [[VAL_102]], [[VAL_103]] : f32
+    // CHECK:   [[VAL_106:%.*]] = arith.select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32
     // CHECK:   cf.br ^bb35([[VAL_106]] : f32)
     // CHECK: ^bb34:
     // CHECK:   cf.br ^bb35([[VAL_102]] : f32)
@@ -150,7 +164,8 @@ gpu.module @kernels {
     // CHECK:   [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle xor [[VAL_107]], [[VAL_10]], [[VAL_83]] : f32
     // CHECK:   cf.cond_br [[VAL_109]], ^bb36, ^bb37
     // CHECK: ^bb36:
-    // CHECK:   [[VAL_111:%.*]] = arith.maxnumf [[VAL_107]], [[VAL_108]] : f32
+    // CHECK:   [[VAL_110:%.*]] = arith.cmpf ugt, [[VAL_107]], [[VAL_108]] : f32
+    // CHECK:   [[VAL_111:%.*]] = arith.select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32
     // CHECK:   cf.br ^bb38([[VAL_111]] : f32)
     // CHECK: ^bb37:
     // CHECK:   cf.br ^bb38([[VAL_107]] : f32)
@@ -158,15 +173,20 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb40([[VAL_112]] : f32)
     // CHECK: ^bb39:
     // CHECK:   [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle xor [[VAL_86]], [[VAL_6]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_116:%.*]] = arith.maxnumf [[VAL_86]], [[VAL_113]] : f32
+    // CHECK:   [[VAL_115:%.*]] = arith.cmpf ugt, [[VAL_86]], [[VAL_113]] : f32
+    // CHECK:   [[VAL_116:%.*]] = arith.select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32
     // CHECK:   [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle xor [[VAL_116]], [[VAL_7]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_120:%.*]] = arith.maxnumf [[VAL_116]], [[VAL_117]] : f32
+    // CHECK:   [[VAL_119:%.*]] = arith.cmpf ugt, [[VAL_116]], [[VAL_117]] : f32
+    // CHECK:   [[VAL_120:%.*]] = arith.select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32
     // CHECK:   [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle xor [[VAL_120]], [[VAL_8]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_124:%.*]] = arith.maxnumf [[VAL_120]], [[VAL_121]] : f32
+    // CHECK:   [[VAL_123:%.*]] = arith.cmpf ugt, [[VAL_120]], [[VAL_121]] : f32
+    // CHECK:   [[VAL_124:%.*]] = arith.select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32
     // CHECK:   [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle xor [[VAL_124]], [[VAL_9]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_128:%.*]] = arith.maxnumf [[VAL_124]], [[VAL_125]] : f32
+    // CHECK:   [[VAL_127:%.*]] = arith.cmpf ugt, [[VAL_124]], [[VAL_125]] : f32
+    // CHECK:   [[VAL_128:%.*]] = arith.select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32
     // CHECK:   [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle xor [[VAL_128]], [[VAL_10]], [[VAL_5]] : f32
-    // CHECK:   [[VAL_132:%.*]] = arith.maxnumf [[VAL_128]], [[VAL_129]] : f32
+    // CHECK:   [[VAL_131:%.*]] = arith.cmpf ugt, [[VAL_128]], [[VAL_129]] : f32
+    // CHECK:   [[VAL_132:%.*]] = arith.select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
     // CHECK:   cf.br ^bb40([[VAL_132]] : f32)
     // CHECK: ^bb40([[VAL_133:%.*]]: f32):
     // CHECK:   store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
@@ -175,7 +195,7 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb42
     // CHECK: ^bb42:
     // CHECK:   gpu.barrier
-    %sum = gpu.all_reduce maxf %arg0 uniform {} : (f32) -> (f32)
+    %sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
     gpu.return
   }
 

diff  --git a/mlir/test/Dialect/GPU/all-reduce-add.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
similarity index 100%
rename from mlir/test/Dialect/GPU/all-reduce-add.mlir
rename to mlir/test/Dialect/GPU/all-reduce.mlir

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 17faccbd091a8bc..3a2197ad4d5a172 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -210,14 +210,6 @@ module attributes {gpu.container_module} {
 
 // -----
 
-func.func @reduce_bad_type(%arg0 : vector<4xf32>) {
-  // expected-error@+1 {{'gpu.all_reduce' op operand #0 must be Integer or Float}}
-  %res = gpu.all_reduce add %arg0 {} : (vector<4xf32>) -> vector<4xf32>
-  return
-}
-
-// -----
-
 func.func @reduce_no_op_no_body(%arg0 : f32) {
   // expected-error@+1 {{expected either an op attribute or a non-empty body}}
   %res = "gpu.all_reduce"(%arg0) ({}) : (f32) -> (f32)
@@ -245,118 +237,22 @@ func.func @reduce_invalid_op(%arg0 : f32) {
 
 // -----
 
-func.func @reduce_invalid_op_type_minsi(%arg0 : f32) {
-  // expected-error@+1 {{`minsi` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce minsi %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minui(%arg0 : f32) {
-  // expected-error@+1 {{`minui` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce minui %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxsi(%arg0 : f32) {
-  // expected-error@+1 {{`maxsi` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce maxsi %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxui(%arg0 : f32) {
-  // expected-error@+1 {{`maxui` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce maxui %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_and(%arg0 : f32) {
-  // expected-error@+1 {{`and` reduction operation is not compatible with type 'f32'}}
+func.func @reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
   %res = gpu.all_reduce and %arg0 {} : (f32) -> (f32)
   return
 }
 
 // -----
 
-func.func @reduce_invalid_op_type_or(%arg0 : f32) {
-  // expected-error@+1 {{`or` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce or %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_xor(%arg0 : f32) {
-  // expected-error@+1 {{`xor` reduction operation is not compatible with type 'f32'}}
-  %res = gpu.all_reduce xor %arg0 {} : (f32) -> (f32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minf(%arg0 : i32) {
-  // expected-error@+1 {{`minf` reduction operation is not compatible with type 'i32'}}
-  %res = gpu.all_reduce minf %arg0 {} : (i32) -> (i32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maxf(%arg0 : i32) {
-  // expected-error@+1 {{`maxf` reduction operation is not compatible with type 'i32'}}
-  %res = gpu.all_reduce maxf %arg0 {} : (i32) -> (i32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_minimumf(%arg0 : i32) {
-  // expected-error@+1 {{`minimumf` reduction operation is not compatible with type 'i32'}}
-  %res = gpu.all_reduce minimumf %arg0 {} : (i32) -> (i32)
-  return
-}
-
-// -----
-
-func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
-  // expected-error@+1 {{`maximumf` reduction operation is not compatible with type 'i32'}}
-  %res = gpu.all_reduce maximumf %arg0 {} : (i32) -> (i32)
-  return
-}
-
-// -----
-
-func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
-  %res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
-  return
-}
-
-// -----
-
-func.func @subgroup_reduce_invalid_op_type_and(%arg0 : f32) {
-  // expected-error@+1 {{`and` reduction operation is not compatible with type 'f32'}}
+func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
   %res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
   return
 }
 
 // -----
 
-func.func @subgroup_reduce_invalid_op_type_maxf(%arg0 : i32) {
-  // expected-error@+1 {{`maxf` reduction operation is not compatible with type 'i32'}}
-  %res = gpu.subgroup_reduce maxf %arg0 : (i32) -> (i32)
-  return
-}
-
-// -----
-
 func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error@+1 {{expected two region arguments}}
   %res = gpu.all_reduce %arg0 {
@@ -751,11 +647,11 @@ func.func @main() {
   %shmemSize = arith.constant 10000 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
-    %0 = gpu.dynamic_shared_memory : memref<?xi8>
+    %0 = gpu.dynamic_shared_memory : memref<?xi8>  
     gpu.terminator
   }
   return
@@ -768,11 +664,11 @@ func.func @main() {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
-    %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
+    %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>  
     gpu.terminator
   }
   return
@@ -784,7 +680,7 @@ func.func @main(%arg0 : index) {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
@@ -800,7 +696,7 @@ func.func @main(%arg0 : index) {
   %shmemSize = arith.constant 8192 : i32
   %c1 = arith.constant 1 : index
   gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
-             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
     // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}


        


More information about the Mlir-commits mailing list