[Mlir-commits] [mlir] [mlir][gpu] Align reduction operations with vector combining kinds (PR #73423)
llvmlistbot at llvm.org
Sat Nov 25 20:29:20 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
Author: Jakub Kuderski (kuhar)
The motivation for this change is explained in
https://github.com/llvm/llvm-project/issues/72354.
Before this change, we could not distinguish between signed and unsigned minimum/maximum, or specify the NaN treatment for floating-point values.
The mapping of old reduction operations to the new ones is as follows:
* `min` --> `minsi` for ints, `minf` for floats
* `max` --> `maxsi` for ints, `maxf` for floats
New reduction kinds not represented in the old enum: `minui`, `maxui`, `minimumf`, `maximumf`; see the usage sketch below.
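For illustration, a minimal usage sketch (not part of the patch; the value names `%x` and `%y` are hypothetical) showing how a former `min` reduction is spelled with the new kinds:

```mlir
// Old spelling, no longer accepted after this change:
//   %r = gpu.all_reduce min %x uniform {} : (i32) -> (i32)

// New spellings with explicit signedness and NaN semantics:
%r0 = gpu.all_reduce minsi %x uniform {} : (i32) -> (i32)     // signed integer min
%r1 = gpu.all_reduce minui %x uniform {} : (i32) -> (i32)     // unsigned integer min
%r2 = gpu.all_reduce minf %y uniform {} : (f32) -> (f32)      // arith.minnumf semantics
%r3 = gpu.all_reduce minimumf %y uniform {} : (f32) -> (f32)  // arith.minimumf semantics
```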
As a next step, I would like to have a common definition of combining kinds used by the `vector` and `gpu` dialects. Separately, the GPU to SPIR-V lowering does not yet properly handle signed zeros and NaN values -- the behavior of floating-point min/max group reductions is not specified by the SPIR-V spec.
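For context, the NaN treatment of the new floating-point kinds mirrors the corresponding `arith` ops; a small sketch of the difference (hypothetical values, not from the patch), where `%nan` holds a NaN:

```mlir
// minf/maxf follow arith.minnumf/arith.maxnumf: if one operand is NaN,
// the other, numeric operand is returned.
%a = arith.minnumf %x, %nan : f32   // yields %x
// minimumf/maximumf follow arith.minimumf/arith.maximumf: NaN propagates.
%b = arith.minimumf %x, %nan : f32  // yields NaN
```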
Issue: https://github.com/llvm/llvm-project/issues/72354
---
Patch is 41.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73423.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+38-14)
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+16-5)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+42-16)
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+27-12)
- (modified) mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp (+23-28)
- (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+11-11)
- (modified) mlir/test/Conversion/GPUToSPIRV/reductions.mlir (+32-17)
- (renamed) mlir/test/Dialect/GPU/all-reduce-add.mlir ()
- (renamed) mlir/test/Dialect/GPU/all-reduce-maxf.mlir (+26-46)
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+98-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e11c5c393648de7..424e30a32830890 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -868,25 +868,41 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
}];
}
-// add, mul mirror the XLA ComparisonDirection enum.
+// These mirror the reduction combining kinds from the vector dialect.
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
-def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
-def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
-def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
-def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
-def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
-def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
+def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
+def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
+def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
+// Follows the `arith.minnumf` semantics.
+def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
+def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
+def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
+// Follows the `arith.maxnumf` semantics.
+def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
+def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
+def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
+def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
+// Follows the `arith.minimumf` semantics.
+def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
+// Follows the `arith.maximumf` semantics.
+def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
"built-in reduction operations supported by gpu.allreduce.",
[
GPU_AllReduceOpAdd,
- GPU_AllReduceOpAnd,
- GPU_AllReduceOpMax,
- GPU_AllReduceOpMin,
GPU_AllReduceOpMul,
+ GPU_AllReduceOpMinUI,
+ GPU_AllReduceOpMinSI,
+ GPU_AllReduceOpMinF,
+ GPU_AllReduceOpMaxUI,
+ GPU_AllReduceOpMaxSI,
+ GPU_AllReduceOpMaxF,
+ GPU_AllReduceOpAnd,
GPU_AllReduceOpOr,
- GPU_AllReduceOpXor
+ GPU_AllReduceOpXor,
+ GPU_AllReduceOpMinimumF,
+ GPU_AllReduceOpMaximumF
]>{
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
@@ -918,8 +934,11 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
compute the sum of each work item's %0 value. The first version specifies
the accumulation as operation, whereas the second version specifies the
- accumulation as code region. The accumulation operation must be one of:
- `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
+ accumulation as code region. The reduction operation must be one of:
+ * Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
+ `or`, `xor`
+ * Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
+ `maximumf`
If `uniform` flag is set either none or all work items of a workgroup
need to execute this op in convergence.
@@ -951,7 +970,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
```
If `uniform` flag is set either none or all work items of a subgroup
- need to execute this op in convergence.
+ need to execute this op in convergence. The reduction operation must be one
+ of:
+ * Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
+ `or`, `xor`
+ * Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
+ `maximumf`
}];
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)? attr-dict
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9579..d49f7adaee9ba74 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -65,17 +65,28 @@ convertReduxKind(gpu::AllReduceOperation mode) {
switch (mode) {
case gpu::AllReduceOperation::ADD:
return NVVM::ReduxKind::ADD;
+ case gpu::AllReduceOperation::MUL:
+ return std::nullopt;
+ case gpu::AllReduceOperation::MINSI:
+ return NVVM::ReduxKind::MIN;
+ case gpu::AllReduceOperation::MINUI:
+ return std::nullopt;
+ case gpu::AllReduceOperation::MINF:
+ return NVVM::ReduxKind::MIN;
+ case gpu::AllReduceOperation::MAXSI:
+ return NVVM::ReduxKind::MAX;
+ case gpu::AllReduceOperation::MAXUI:
+ return std::nullopt;
+ case gpu::AllReduceOperation::MAXF:
+ return NVVM::ReduxKind::MAX;
case gpu::AllReduceOperation::AND:
return NVVM::ReduxKind::AND;
- case gpu::AllReduceOperation::MAX:
- return NVVM::ReduxKind::MAX;
- case gpu::AllReduceOperation::MIN:
- return NVVM::ReduxKind::MIN;
case gpu::AllReduceOperation::OR:
return NVVM::ReduxKind::OR;
case gpu::AllReduceOperation::XOR:
return NVVM::ReduxKind::XOR;
- case gpu::AllReduceOperation::MUL:
+ case gpu::AllReduceOperation::MINIMUMF:
+ case gpu::AllReduceOperation::MAXIMUMF:
return std::nullopt;
}
return std::nullopt;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 693cc3f6236b574..272ebb8e357b40c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -503,26 +503,52 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
return std::nullopt;
}
+ // TODO: The SPIR-V spec does not specify how -0.0 / +0.0 and NaN values are
+  // handled in *FMin/*FMax reduction ops. We should account for this not being
+  // defined in this conversion.
+
using ReduceType = gpu::AllReduceOperation;
- namespace spv = spirv;
const OpHandler handlers[] = {
{ReduceType::ADD,
- &createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
- &createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
+ &createGroupReduceOpImpl<spirv::GroupIAddOp,
+ spirv::GroupNonUniformIAddOp>,
+ &createGroupReduceOpImpl<spirv::GroupFAddOp,
+ spirv::GroupNonUniformFAddOp>},
{ReduceType::MUL,
- &createGroupReduceOpImpl<spv::GroupIMulKHROp,
- spv::GroupNonUniformIMulOp>,
- &createGroupReduceOpImpl<spv::GroupFMulKHROp,
- spv::GroupNonUniformFMulOp>},
- {ReduceType::MIN,
- &createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
- &createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
- {ReduceType::MAX,
- &createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
- &createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
- };
-
- for (auto &handler : handlers)
+ &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
+ spirv::GroupNonUniformIMulOp>,
+ &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
+ spirv::GroupNonUniformFMulOp>},
+ {ReduceType::MINUI,
+ &createGroupReduceOpImpl<spirv::GroupUMinOp,
+ spirv::GroupNonUniformUMinOp>,
+ nullptr},
+ {ReduceType::MINSI,
+ &createGroupReduceOpImpl<spirv::GroupSMinOp,
+ spirv::GroupNonUniformSMinOp>,
+ nullptr},
+ {ReduceType::MINF, nullptr,
+ &createGroupReduceOpImpl<spirv::GroupFMinOp,
+ spirv::GroupNonUniformFMinOp>},
+ {ReduceType::MAXUI,
+ &createGroupReduceOpImpl<spirv::GroupUMaxOp,
+ spirv::GroupNonUniformUMaxOp>,
+ nullptr},
+ {ReduceType::MAXSI,
+ &createGroupReduceOpImpl<spirv::GroupSMaxOp,
+ spirv::GroupNonUniformSMaxOp>,
+ nullptr},
+ {ReduceType::MAXF, nullptr,
+ &createGroupReduceOpImpl<spirv::GroupFMaxOp,
+ spirv::GroupNonUniformFMaxOp>},
+ {ReduceType::MINIMUMF, nullptr,
+ &createGroupReduceOpImpl<spirv::GroupFMinOp,
+ spirv::GroupNonUniformFMinOp>},
+ {ReduceType::MAXIMUMF, nullptr,
+ &createGroupReduceOpImpl<spirv::GroupFMaxOp,
+ spirv::GroupNonUniformFMaxOp>}};
+
+ for (const OpHandler &handler : handlers)
if (handler.type == opType)
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9517c053c8360ef..49a8d96e83c5a0b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -27,7 +27,9 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
@@ -485,12 +487,23 @@ static LogicalResult verifyAttributions(Operation *op,
// AllReduceOp
//===----------------------------------------------------------------------===//
-static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
- Type resType) {
- return (opName != gpu::AllReduceOperation::AND &&
- opName != gpu::AllReduceOperation::OR &&
- opName != gpu::AllReduceOperation::XOR) ||
- llvm::isa<IntegerType>(resType);
+static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
+ Type resType) {
+ using Kind = gpu::AllReduceOperation;
+ if (llvm::is_contained(
+ {Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
+ if (!isa<FloatType>(resType))
+ return failure();
+ }
+
+ if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
+ Kind::AND, Kind::OR, Kind::XOR},
+ opName)) {
+ if (!isa<IntegerType>(resType))
+ return failure();
+ }
+
+ return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -517,12 +530,13 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
return emitError("expected gpu.yield op in region");
} else {
gpu::AllReduceOperation opName = *getOp();
- if (!verifyReduceOpAndType(opName, getType())) {
- return emitError()
- << '`' << gpu::stringifyAllReduceOperation(opName)
- << "` accumulator is only compatible with Integer type";
+ if (failed(verifyReduceOpAndType(opName, getType()))) {
+ return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
+ << "` reduction operation is not compatible with type "
+ << getType();
}
}
+
return success();
}
@@ -573,9 +587,10 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
LogicalResult gpu::SubgroupReduceOp::verify() {
gpu::AllReduceOperation opName = getOp();
- if (!verifyReduceOpAndType(opName, getType())) {
+ if (failed(verifyReduceOpAndType(opName, getType()))) {
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
- << "` accumulator is only compatible with Integer type";
+ << "` reduction operation is not compatible with type "
+ << getType();
}
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index acf4f6d0e3d6979..ecee9a7b45e32bd 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,32 +214,37 @@ struct GpuAllReduceRewriter {
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
+ using Kind = gpu::AllReduceOperation;
bool isFloatingPoint = isa<FloatType>(valueType);
switch (opName) {
- case gpu::AllReduceOperation::ADD:
+ case Kind::ADD:
return isFloatingPoint ? getFactory<arith::AddFOp>()
: getFactory<arith::AddIOp>();
- case gpu::AllReduceOperation::MUL:
+ case Kind::MUL:
return isFloatingPoint ? getFactory<arith::MulFOp>()
: getFactory<arith::MulIOp>();
- case gpu::AllReduceOperation::AND:
+ case Kind::MINSI:
+ return getFactory<arith::MinSIOp>();
+ case Kind::MINUI:
+ return getFactory<arith::MinUIOp>();
+ case Kind::MINF:
+ return getFactory<arith::MinNumFOp>();
+ case Kind::MAXSI:
+ return getFactory<arith::MaxSIOp>();
+ case Kind::MAXUI:
+ return getFactory<arith::MaxUIOp>();
+ case Kind::MAXF:
+ return getFactory<arith::MaxNumFOp>();
+ case Kind::AND:
return getFactory<arith::AndIOp>();
- case gpu::AllReduceOperation::OR:
+ case Kind::OR:
return getFactory<arith::OrIOp>();
- case gpu::AllReduceOperation::XOR:
+ case Kind::XOR:
return getFactory<arith::XOrIOp>();
- case gpu::AllReduceOperation::MAX:
- return isFloatingPoint
- ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
- arith::CmpFPredicate::UGT>()
- : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
- arith::CmpIPredicate::ugt>();
- case gpu::AllReduceOperation::MIN:
- return isFloatingPoint
- ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
- arith::CmpFPredicate::ULT>()
- : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
- arith::CmpIPredicate::ult>();
+ case Kind::MINIMUMF:
+ return getFactory<arith::MinimumFOp>();
+ case Kind::MAXIMUMF:
+ return getFactory<arith::MaximumFOp>();
}
llvm_unreachable("unknown GPU AllReduceOperation");
}
@@ -247,21 +252,11 @@ struct GpuAllReduceRewriter {
/// Returns an accumulator factory that creates an op of type T.
template <typename T>
AccumulatorFactory getFactory() {
- return [&](Value lhs, Value rhs) {
+ return [this](Value lhs, Value rhs) {
return create<T>(lhs.getType(), lhs, rhs);
};
}
- /// Returns an accumulator for comparison such as min, max. T is the type
- /// of the compare op.
- template <typename T, typename PredicateEnum, PredicateEnum predicate>
- AccumulatorFactory getCmpFactory() const {
- return [&](Value lhs, Value rhs) {
- Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
- return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
- };
- }
-
/// Creates an if-block skeleton and calls the two factories to generate the
/// ops in the `then` and `else` block..
///
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c18bb423a6e6001..20a200e812c1259 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -582,22 +582,22 @@ gpu.module @test_module_30 {
%result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: func @subgroup_reduce_and
- gpu.func @subgroup_reduce_and(%arg0 : i32) {
- // CHECK: nvvm.redux.sync and {{.*}}
- %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
+ // CHECK-LABEL: @subgroup_reduce_minsi
+ gpu.func @subgroup_reduce_minsi(%arg0 : i32) {
+ // CHECK: nvvm.redux.sync min {{.*}}
+ %result = gpu.subgroup_reduce minsi %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: @subgroup_reduce_max
- gpu.func @subgroup_reduce_max(%arg0 : i32) {
+ // CHECK-LABEL: @subgroup_reduce_maxsi
+ gpu.func @subgroup_reduce_maxsi(%arg0 : i32) {
// CHECK: nvvm.redux.sync max {{.*}}
- %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
+ %result = gpu.subgroup_reduce maxsi %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
- // CHECK-LABEL: @subgroup_reduce_min
- gpu.func @subgroup_reduce_min(%arg0 : i32) {
- // CHECK: nvvm.redux.sync min {{.*}}
- %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
+ // CHECK-LABEL: func @subgroup_reduce_and
+ gpu.func @subgroup_reduce_and(%arg0 : i32) {
+ // CHECK: nvvm.redux.sync and {{.*}}
+ %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
gpu.return
}
// CHECK-LABEL: @subgroup_reduce_or
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index 1e5d64387650ce3..feb7ee185a8a858 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -331,7 +331,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMin <Workgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.all_reduce min %arg uniform {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce minf %arg uniform {} : (f32) -> (f32)
gpu.return
}
}
@@ -351,7 +351,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Workgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.all_reduce min %arg {} : (f32) -> (f32)
+ %reduced = gpu.all_reduce minf %arg {} : (f32) -> (f32)
gpu.return
}
}
@@ -371,7 +371,9 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMin <Workgroup> <Reduce> %[[ARG]] : i32
- %reduced = gpu.all_reduce min %arg uniform {} : (i32) -> (i32)
+ // CHECK: %{{.*}} = spirv.GroupUMin <Workgroup> <Reduce> %[[ARG]] : i32
+ %r0 = gpu.all_reduce minsi %arg uniform {} : (i32) -> (i32)
+ %r1 = gpu.all_reduce minui %arg uniform {} : (i32) -> (i32)
gpu.return
}
}
@@ -390,8 +392,9 @@ gpu.module @kernels {
// CHECK-SAME: (%[[ARG:.*]]: i32)
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
- // CHECK: %{{.*}} = spirv.GroupNonUniformSMin "Workgroup" "Reduce" %[[ARG]] : i32
- %reduced = gpu.all_reduce min %arg {} : (i32) -> (i32)
+ // CHECK: %{{.*}} = spirv.GroupNonUniformUMin "Workgroup" "Reduce" %[[ARG]] : i32
+ %r0 = gpu.all_reduce minsi %arg {} : (i32) -> (i32)
+ %r1 = gpu.all_reduce minui %arg {} : (i32) -> (i32)
gpu.return
}
}
@@ -411,7 +414,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupFMin <Subgroup> <Reduce> %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce min %arg uniform : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce minf %arg uniform : (f32) -> (f32)
gpu.return
}
}
@@ -431,7 +434,7 @@ gpu.module @kernels {
gpu.func @test(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFMin "Subgroup" "Reduce" %[[ARG]] : f32
- %reduced = gpu.subgroup_reduce min %arg : (f32) -> (f32)
+ %reduced = gpu.subgroup_reduce minf %arg : (f32) -> (f32)
gpu.return
}
}
@@ -451,7 +454,9 @@ gpu.module @kernels {
gpu.func @test(%arg : i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupSMin <Subgroup> <Reduce> %[[ARG]] : i32
- %reduced = gpu.subgroup_reduce min %arg uniform : (i32) -> (i32)
+ // CHECK: %{{.*}} = spirv.G...
[truncated]
``````````
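As a usage note, the updated `verifyReduceOpAndType` above rejects kind/type mismatches in both directions. A sketch of IR that would now fail verification (the value name is hypothetical; the diagnostic text follows the new verifier):

```mlir
// Invalid: a floating-point reduction kind applied to an integer value.
// Expected diagnostic: `minf` reduction operation is not compatible with type 'i32'
%bad = gpu.all_reduce minf %arg uniform {} : (i32) -> (i32)
```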
https://github.com/llvm/llvm-project/pull/73423
More information about the Mlir-commits mailing list