[Mlir-commits] [mlir] [mlir][spirv][gpu] Add lowering for `gpu.subgroup_broadcast` (PR #185818)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 10 23:52:43 PDT 2026
https://github.com/hankluo6 updated https://github.com/llvm/llvm-project/pull/185818
>From cb0d77ecec0cebbe8986366e470d04fa8a05285f Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 20:01:20 -0800
Subject: [PATCH 1/2] Add lowering for gpu.subgroup_broadcast
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 4 +-
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 49 +++++++++++++++++++
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 39 +++++++++++++++
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 12 +++++
4 files changed, 103 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2f189c64300ae..0e489dfc8386d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4587,6 +4587,7 @@ def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
+def SPIRV_OC_OpGroupNonUniformBroadcastFirst : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>;
def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
def SPIRV_OC_OpGroupNonUniformBallotFindLSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
@@ -4725,7 +4726,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
- SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
+ SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
+ SPIRV_OC_OpGroupNonUniformBallot,
SPIRV_OC_OpGroupNonUniformBallotBitCount,
SPIRV_OC_OpGroupNonUniformBallotFindLSB,
SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 784eb40141b74..a8bdb4256ed1b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -269,6 +269,55 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
// -----
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
+ [Pure, AllTypesMatch<["value", "result"]>]> {
+ let summary = [{
+ Result is the Value of the invocation from the active invocation with
+ the lowest id in the group to all active invocations in the group.
+ }];
+
+ let description = [{
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution must be Workgroup or Subgroup Scope.
+
+ The type of Value must be the same as Result Type.
+
+ #### Example:
+
+ ```mlir
+ %scalar_value = ... : f32
+ %vector_value = ... : vector<4xf32>
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
+ %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
+ vector<4xf32>, i32
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_3>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_GroupNonUniformBallot]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScopeAttr:$execution_scope,
+ SPIRV_Type:$value
+ );
+
+ let results = (outs
+ SPIRV_Type:$result
+ );
+
+ let assemblyFormat = [{
+ $execution_scope operands attr-dict `:` type($value)
+ }];
+}
+
+// -----
+
def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
let summary = [{
Result is true only in the active invocation with the lowest id in the
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..d5269511d61a2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,6 +130,18 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Pattern to convert a gpu.subgroup_broadcast op into a
+/// spirv.GroupNonUniformBroadcast or spirv.GroupNonUniformBroadcastFirst op.
+class GPUSubgroupBroadcastConversion final
+ : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
+public:
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using Base::Base;
@@ -542,6 +554,32 @@ LogicalResult GPURotateConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Subgroup broadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
+ gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value result;
+
+ switch (op.getBroadcastType()) {
+ case gpu::BroadcastType::specific_lane:
+ result = spirv::GroupNonUniformBroadcastOp::create(
+ rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
+ break;
+ case gpu::BroadcastType::first_active_lane:
+ result = spirv::GroupNonUniformBroadcastFirstOp::create(
+ rewriter, loc, scope, adaptor.getSrc());
+ break;
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -832,6 +870,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+ GPUSubgroupBroadcastConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..f1a216504a7e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,6 +135,18 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirstOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
+ spirv::Scope scope = getExecutionScope();
+ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+ return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformShuffle*
//===----------------------------------------------------------------------===//
>From dae5bcf304ac7c9e0fbef34186430c93efe32f8a Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 22:54:21 -0800
Subject: [PATCH 2/2] Add test
---
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 4 +--
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 28 +++++++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index a8bdb4256ed1b..1ded87fc2c090 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -290,8 +290,8 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
%scalar_value = ... : f32
%vector_value = ... : vector<4xf32>
%0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
- %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
- vector<4xf32>, i32
+ %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
+ vector<4xf32>
```
}];
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5383f7656a1be..2301379240a1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -123,6 +123,34 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// -----
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirst
+//===----------------------------------------------------------------------===//
+
+func.func @group_non_uniform_broadcast_first_scalar_workgroup(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
+ // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
+ return %0 : f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list