[Mlir-commits] [mlir] [mlir][spirv] Implement lowering `gpu.subgroup_reduce` with cluster size for SPIRV (PR #141402)

Darren Wihandi llvmlistbot at llvm.org
Tue Jun 3 14:40:06 PDT 2025


https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/141402

>From 377841b9a9b7192160ad4bc59be1a41276edc7b7 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 25 May 2025 03:31:17 -0400
Subject: [PATCH 1/3] [mlir][spirv] Implement lowering of `gpu.subgroup_reduce`
 with cluster size for SPIRV

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 45 ++++++++++++-------
 .../Conversion/GPUToSPIRV/reductions.mlir     | 41 +++++++++++++++++
 2 files changed, 70 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..f42605a6e8ce1 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 
 template <typename UniformOp, typename NonUniformOp>
 static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
-                                     Value arg, bool isGroup, bool isUniform) {
+                                     Value arg, bool isGroup, bool isUniform,
+                                     std::optional<uint32_t> clusterSize) {
   Type type = arg.getType();
   auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                            isGroup ? spirv::Scope::Workgroup
                                                    : spirv::Scope::Subgroup);
-  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
-                                                spirv::GroupOperation::Reduce);
+  auto groupOp = spirv::GroupOperationAttr::get(
+      builder.getContext(), clusterSize.has_value()
+                                ? spirv::GroupOperation::ClusteredReduce
+                                : spirv::GroupOperation::Reduce);
   if (isUniform) {
     return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
         .getResult();
   }
-  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+  Value clusterSizeValue =
+      clusterSize.has_value()
+          ? builder.create<spirv::ConstantOp>(
+                loc, builder.getI32Type(),
+                builder.getIntegerAttr(builder.getI32Type(), *clusterSize))
+          : Value{};
+  return builder
+      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();
 }
 
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
-                                                Location loc, Value arg,
-                                                gpu::AllReduceOperation opType,
-                                                bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+                    gpu::AllReduceOperation opType, bool isGroup,
+                    bool isUniform, std::optional<uint32_t> clusterSize) {
   enum class ElemType { Float, Boolean, Integer };
-  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+                          std::optional<uint32_t>);
   struct OpHandler {
     gpu::AllReduceOperation kind;
     ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   for (const OpHandler &handler : handlers)
     if (handler.kind == opType && elementType == handler.elemType)
-      return handler.func(builder, loc, arg, isGroup, isUniform);
+      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
 
   return std::nullopt;
 }
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
 
     auto result =
         createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
-                            /*isGroup*/ true, op.getUniform());
+                            /*isGroup*/ true, op.getUniform(), std::nullopt);
     if (!result)
       return failure();
 
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getClusterSize())
+    if (op.getClusterStride() > 1) {
       return rewriter.notifyMatchFailure(
-          op, "lowering for clustered reduce not implemented");
+          op, "lowering for cluster stride > 1 is not implemented");
+    }
 
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
-    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
-                                      adaptor.getOp(),
-                                      /*isGroup=*/false, adaptor.getUniform());
+    auto result = createGroupReduceOp(
+        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
     if (!result)
       return failure();
 
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index ae834b9915d50..08d9b094a5303 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -789,3 +789,44 @@ gpu.module @kernels {
   }
 }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: f32)
+  //  CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
+  gpu.func @test22(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Subgroup reduce with cluster stride > 1 is not yet supported.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @test22(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}

>From 9dbd3d26f2b0e2074cf6c16521718b36d8306586 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 25 May 2025 20:22:42 -0400
Subject: [PATCH 2/3] Use better function names in test

---
 mlir/test/Conversion/GPUToSPIRV/reductions.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index 08d9b094a5303..e7e0fa296c98a 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -798,10 +798,10 @@ module attributes {
 } {
 
 gpu.module @kernels {
-  // CHECK-LABEL:  spirv.func @test
+  // CHECK-LABEL:  spirv.func @test_subgroup_reduce_clustered
   //  CHECK-SAME: (%[[ARG:.*]]: f32)
   //  CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
-  gpu.func @test22(%arg : f32) kernel
+  gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
     %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
@@ -821,7 +821,7 @@ module attributes {
 } {
 
 gpu.module @kernels {
-  gpu.func @test22(%arg : f32) kernel
+  gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel
     attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
     // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
     %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)

>From a865235355de587f8daff8f6411e6c936fb69fdd Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 3 Jun 2025 15:39:50 -0600
Subject: [PATCH 3/3] Use if statement instead of ternary

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index f42605a6e8ce1..67c82f4b9653a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -479,12 +479,12 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
         .getResult();
   }
 
-  Value clusterSizeValue =
-      clusterSize.has_value()
-          ? builder.create<spirv::ConstantOp>(
-                loc, builder.getI32Type(),
-                builder.getIntegerAttr(builder.getI32Type(), *clusterSize))
-          : Value{};
+  Value clusterSizeValue = {};
+  if (clusterSize.has_value())
+    clusterSizeValue = builder.create<spirv::ConstantOp>(
+        loc, builder.getI32Type(),
+        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
   return builder
       .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();



More information about the Mlir-commits mailing list