[Mlir-commits] [mlir] c9c6017 - [mlir][spirv] Implement lowering `gpu.subgroup_reduce` with cluster size for SPIRV (#141402)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 6 09:50:21 PDT 2025


Author: Darren Wihandi
Date: 2025-06-06T12:50:18-04:00
New Revision: c9c60172a187eab07ab6ac4168862862074e6721

URL: https://github.com/llvm/llvm-project/commit/c9c60172a187eab07ab6ac4168862862074e6721
DIFF: https://github.com/llvm/llvm-project/commit/c9c60172a187eab07ab6ac4168862862074e6721.diff

LOG: [mlir][spirv] Implement lowering `gpu.subgroup_reduce` with cluster size for SPIRV (#141402)

Implement lowering of `gpu.subgroup_reduce` with a cluster size
attribute to SPIRV by using the `ClusteredReduce` group operation.

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/test/Conversion/GPUToSPIRV/reductions.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 78e6ebb523a46..46db5d3fdca3b 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 
 template <typename UniformOp, typename NonUniformOp>
 static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
-                                     Value arg, bool isGroup, bool isUniform) {
+                                     Value arg, bool isGroup, bool isUniform,
+                                     std::optional<uint32_t> clusterSize) {
   Type type = arg.getType();
   auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                            isGroup ? spirv::Scope::Workgroup
                                                    : spirv::Scope::Subgroup);
-  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
-                                                spirv::GroupOperation::Reduce);
+  auto groupOp = spirv::GroupOperationAttr::get(
+      builder.getContext(), clusterSize.has_value()
+                                ? spirv::GroupOperation::ClusteredReduce
+                                : spirv::GroupOperation::Reduce);
   if (isUniform) {
     return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
         .getResult();
   }
-  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+  Value clusterSizeValue;
+  if (clusterSize.has_value())
+    clusterSizeValue = builder.create<spirv::ConstantOp>(
+        loc, builder.getI32Type(),
+        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
+  return builder
+      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();
 }
 
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
-                                                Location loc, Value arg,
-                                                gpu::AllReduceOperation opType,
-                                                bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+                    gpu::AllReduceOperation opType, bool isGroup,
+                    bool isUniform, std::optional<uint32_t> clusterSize) {
   enum class ElemType { Float, Boolean, Integer };
-  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+                          std::optional<uint32_t>);
   struct OpHandler {
     gpu::AllReduceOperation kind;
     ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   for (const OpHandler &handler : handlers)
     if (handler.kind == opType && elementType == handler.elemType)
-      return handler.func(builder, loc, arg, isGroup, isUniform);
+      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
 
   return std::nullopt;
 }
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
 
     auto result =
         createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
-                            /*isGroup*/ true, op.getUniform());
+                            /*isGroup*/ true, op.getUniform(), std::nullopt);
     if (!result)
       return failure();
 
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getClusterSize())
+    if (op.getClusterStride() > 1) {
       return rewriter.notifyMatchFailure(
-          op, "lowering for clustered reduce not implemented");
+          op, "lowering for cluster stride > 1 is not implemented");
+    }
 
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
-    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
-                                      adaptor.getOp(),
-                                      /*isGroup=*/false, adaptor.getUniform());
+    auto result = createGroupReduceOp(
+        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
     if (!result)
       return failure();
 

diff  --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index ae834b9915d50..e7e0fa296c98a 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -789,3 +789,44 @@ gpu.module @kernels {
   }
 }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test_subgroup_reduce_clustered
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  //  CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
+  gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Subgroup reduce with cluster stride > 1 is not yet supported.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}


        


More information about the Mlir-commits mailing list