[Mlir-commits] [mlir] [mlir][spirv] Implement lowering `gpu.subgroup_reduce` with cluster size for SPIRV (PR #141402)
Darren Wihandi
llvmlistbot at llvm.org
Tue Jun 3 14:40:06 PDT 2025
https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/141402
>From 377841b9a9b7192160ad4bc59be1a41276edc7b7 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 25 May 2025 03:31:17 -0400
Subject: [PATCH 1/3] [mlir][spirv] Implement lowering of `gpu.subgroup_reduce`
with cluster size for SPIRV
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 45 ++++++++++++-------
.../Conversion/GPUToSPIRV/reductions.mlir | 41 +++++++++++++++++
2 files changed, 70 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..f42605a6e8ce1 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
- Value arg, bool isGroup, bool isUniform) {
+ Value arg, bool isGroup, bool isUniform,
+ std::optional<uint32_t> clusterSize) {
Type type = arg.getType();
auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
isGroup ? spirv::Scope::Workgroup
: spirv::Scope::Subgroup);
- auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
- spirv::GroupOperation::Reduce);
+ auto groupOp = spirv::GroupOperationAttr::get(
+ builder.getContext(), clusterSize.has_value()
+ ? spirv::GroupOperation::ClusteredReduce
+ : spirv::GroupOperation::Reduce);
if (isUniform) {
return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
.getResult();
}
- return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+ Value clusterSizeValue =
+ clusterSize.has_value()
+ ? builder.create<spirv::ConstantOp>(
+ loc, builder.getI32Type(),
+ builder.getIntegerAttr(builder.getI32Type(), *clusterSize))
+ : Value{};
+ return builder
+ .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
.getResult();
}
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
- Location loc, Value arg,
- gpu::AllReduceOperation opType,
- bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+ gpu::AllReduceOperation opType, bool isGroup,
+ bool isUniform, std::optional<uint32_t> clusterSize) {
enum class ElemType { Float, Boolean, Integer };
- using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+ using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+ std::optional<uint32_t>);
struct OpHandler {
gpu::AllReduceOperation kind;
ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
for (const OpHandler &handler : handlers)
if (handler.kind == opType && elementType == handler.elemType)
- return handler.func(builder, loc, arg, isGroup, isUniform);
+ return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
return std::nullopt;
}
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
auto result =
createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
- /*isGroup*/ true, op.getUniform());
+ /*isGroup*/ true, op.getUniform(), std::nullopt);
if (!result)
return failure();
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
LogicalResult
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.getClusterSize())
+ if (op.getClusterStride() > 1) {
return rewriter.notifyMatchFailure(
- op, "lowering for clustered reduce not implemented");
+ op, "lowering for cluster stride > 1 is not implemented");
+ }
if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
- auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
- adaptor.getOp(),
- /*isGroup=*/false, adaptor.getUniform());
+ auto result = createGroupReduceOp(
+ rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+ /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
if (!result)
return failure();
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index ae834b9915d50..08d9b094a5303 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -789,3 +789,44 @@ gpu.module @kernels {
}
}
}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @test
+ // CHECK-SAME: (%[[ARG:.*]]: f32)
+ // CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
+ gpu.func @test22(%arg : f32) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
+ %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
+ gpu.return
+ }
+}
+
+}
+
+// -----
+
+// Subgroup reduce with cluster stride > 1 is not yet supported.
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+ gpu.func @test22(%arg : f32) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+ %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
+ gpu.return
+ }
+}
+
+}
>From 9dbd3d26f2b0e2074cf6c16521718b36d8306586 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 25 May 2025 20:22:42 -0400
Subject: [PATCH 2/3] Use better function names in test
---
mlir/test/Conversion/GPUToSPIRV/reductions.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index 08d9b094a5303..e7e0fa296c98a 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -798,10 +798,10 @@ module attributes {
} {
gpu.module @kernels {
- // CHECK-LABEL: spirv.func @test
+ // CHECK-LABEL: spirv.func @test_subgroup_reduce_clustered
// CHECK-SAME: (%[[ARG:.*]]: f32)
// CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
- gpu.func @test22(%arg : f32) kernel
+ gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
%reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
@@ -821,7 +821,7 @@ module attributes {
} {
gpu.module @kernels {
- gpu.func @test22(%arg : f32) kernel
+ gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
// expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
%reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
>From a865235355de587f8daff8f6411e6c936fb69fdd Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 3 Jun 2025 15:39:50 -0600
Subject: [PATCH 3/3] Use if statement instead of ternary
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index f42605a6e8ce1..67c82f4b9653a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -479,12 +479,12 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
.getResult();
}
- Value clusterSizeValue =
- clusterSize.has_value()
- ? builder.create<spirv::ConstantOp>(
- loc, builder.getI32Type(),
- builder.getIntegerAttr(builder.getI32Type(), *clusterSize))
- : Value{};
+ Value clusterSizeValue = {};
+ if (clusterSize.has_value())
+ clusterSizeValue = builder.create<spirv::ConstantOp>(
+ loc, builder.getI32Type(),
+ builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
return builder
.create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
.getResult();
More information about the Mlir-commits
mailing list