[Mlir-commits] [mlir] [mlir][spirv][gpu] Add lowering for `gpu.subgroup_broadcast` (PR #185818)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 10 23:46:02 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hank (hankluo6)

<details>
<summary>Changes</summary>

Fixes #<!-- -->157940 

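Add `spirv.GroupNonUniformBroadcastFirst` and lower `gpu.subgroup_broadcast` to SPIR-V non-uniform broadcast ops: the `specific_lane` mode lowers to `spirv.GroupNonUniformBroadcast` and the `first_active_lane` mode lowers to `spirv.GroupNonUniformBroadcastFirst`.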

---
Full diff: https://github.com/llvm/llvm-project/pull/185818.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+3-1) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td (+49) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+39) 
- (modified) mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp (+12) 
- (modified) mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir (+28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2f189c64300ae..0e489dfc8386d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4587,6 +4587,7 @@ def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
 def SPIRV_OC_OpGroupNonUniformAllEqual        : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
 def SPIRV_OC_OpGroupNonUniformBroadcast       : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
+def SPIRV_OC_OpGroupNonUniformBroadcastFirst  : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>;
 def SPIRV_OC_OpGroupNonUniformBallot          : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPIRV_OC_OpGroupNonUniformBallotBitCount  : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
 def SPIRV_OC_OpGroupNonUniformBallotFindLSB   : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
@@ -4725,7 +4726,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
       SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
       SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
-      SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
+      SPIRV_OC_OpGroupNonUniformBroadcast,
+      SPIRV_OC_OpGroupNonUniformBroadcastFirst,
+      SPIRV_OC_OpGroupNonUniformBallot,
       SPIRV_OC_OpGroupNonUniformBallotBitCount,
       SPIRV_OC_OpGroupNonUniformBallotFindLSB,
       SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 784eb40141b74..1ded87fc2c090 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -269,6 +269,55 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
 
 // -----
 
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
+  [Pure, AllTypesMatch<["value", "result"]>]> {
+  let summary = [{
+    Result is the Value of the invocation from the active invocations with
+    the lowest id in the group to all active invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The type of Value must be the same as Result Type.
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
+    %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
+      vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformBallot]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Type:$value
+  );
+
+  let results = (outs
+    SPIRV_Type:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope operands attr-dict `:` type($value)
+  }];
+}
+
+// -----
+
 def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
   let summary = [{
     Result is true only in the active invocation with the lowest id in the
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..d5269511d61a2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,6 +130,19 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
                    ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.subgroup_broadcast op into a
+/// spirv.GroupNonUniformBroadcast op (`specific_lane`) or a
+/// spirv.GroupNonUniformBroadcastFirst op (`first_active_lane`).
+class GPUSubgroupBroadcastConversion final
+    : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
+public:
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using Base::Base;
@@ -542,6 +555,36 @@ LogicalResult GPURotateConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Subgroup broadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
+    gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // The op broadcasts within a subgroup, so both lowerings use Subgroup scope.
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+
+  switch (op.getBroadcastType()) {
+  case gpu::BroadcastType::specific_lane:
+    // Per the SPIR-V spec, before version 1.5 `Id` must come from a constant
+    // instruction, so a non-constant `lane` may fail verification on older
+    // targets.
+    rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBroadcastOp>(
+        op, scope, adaptor.getSrc(), adaptor.getLane());
+    return success();
+  case gpu::BroadcastType::first_active_lane:
+    // Value of the active invocation with the lowest id in the subgroup.
+    rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBroadcastFirstOp>(
+        op, scope, adaptor.getSrc());
+    return success();
+  }
+
+  // Only reachable if gpu::BroadcastType has a case not handled above; fail
+  // the match instead of replacing `op` with a null value.
+  return rewriter.notifyMatchFailure(op, "unsupported broadcast type");
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -832,6 +875,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
       GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+      GPUSubgroupBroadcastConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..f1a216504a7e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,6 +135,20 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirstOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
+  // The op's Execution scope is restricted to Workgroup or Subgroup; any
+  // other scope is rejected.
+  spirv::Scope executionScope = getExecutionScope();
+  if (executionScope == spirv::Scope::Workgroup ||
+      executionScope == spirv::Scope::Subgroup)
+    return success();
+  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformShuffle*
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5383f7656a1be..2301379240a1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -123,6 +123,42 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirst
+//===----------------------------------------------------------------------===//
+
+func.func @group_non_uniform_broadcast_first_workgroup(%value: f32) -> f32 {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
+  %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
+  return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+  %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+  return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : vector<4xf32>
+  %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
+  return %0 : f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformElect
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/185818


More information about the Mlir-commits mailing list