[Mlir-commits] [mlir] [mlir][spirv][gpu] Add lowering for `gpu.subgroup_broadcast` (PR #185818)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 10 23:46:02 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hank (hankluo6)
<details>
<summary>Changes</summary>
Fixes #<!-- -->157940
Add `spirv.GroupNonUniformBroadcastFirst` and lower `gpu.subgroup_broadcast` / `gpu.subgroup_broadcast_first` to SPIR-V non-uniform broadcast ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/185818.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+3-1)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td (+49)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+39)
- (modified) mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp (+12)
- (modified) mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir (+28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2f189c64300ae..0e489dfc8386d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4587,6 +4587,7 @@ def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
+def SPIRV_OC_OpGroupNonUniformBroadcastFirst : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>;
def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
def SPIRV_OC_OpGroupNonUniformBallotFindLSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
@@ -4725,7 +4726,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
- SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
+ SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
+ SPIRV_OC_OpGroupNonUniformBallot,
SPIRV_OC_OpGroupNonUniformBallotBitCount,
SPIRV_OC_OpGroupNonUniformBallotFindLSB,
SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 784eb40141b74..1ded87fc2c090 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -269,6 +269,55 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
// -----
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
+ [Pure, AllTypesMatch<["value", "result"]>]> {
+ let summary = [{
+ Result is the Value of the invocation from the active invocations with
+ the lowest id in the group to all active invocations in the group.
+ }];
+
+ let description = [{
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution must be Workgroup or Subgroup Scope.
+
+ The type of Value must be the same as Result Type.
+
+ #### Example:
+
+ ```mlir
+ %scalar_value = ... : f32
+ %vector_value = ... : vector<4xf32>
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
+ %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
+ vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_3>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_GroupNonUniformBallot]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScopeAttr:$execution_scope,
+ SPIRV_Type:$value
+ );
+
+ let results = (outs
+ SPIRV_Type:$result
+ );
+
+ let assemblyFormat = [{
+ $execution_scope operands attr-dict `:` type($value)
+ }];
+}
+
+// -----
+
def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
let summary = [{
Result is true only in the active invocation with the lowest id in the
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..d5269511d61a2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,6 +130,18 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Pattern to convert a gpu.subgroup_broadcast op into a
+/// spirv.GroupNonUniformBroadcast or spirv.GroupNonUniformBroadcastFirst op.
+class GPUSubgroupBroadcastConversion final
+ : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
+public:
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using Base::Base;
@@ -542,6 +554,32 @@ LogicalResult GPURotateConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Subgroup broadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
+ gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value result;
+
+ switch (op.getBroadcastType()) {
+ case gpu::BroadcastType::specific_lane:
+ result = spirv::GroupNonUniformBroadcastOp::create(
+ rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
+ break;
+ case gpu::BroadcastType::first_active_lane:
+ result = spirv::GroupNonUniformBroadcastFirstOp::create(
+ rewriter, loc, scope, adaptor.getSrc());
+ break;
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -832,6 +870,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+ GPUSubgroupBroadcastConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..f1a216504a7e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,6 +135,18 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirstOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
+ spirv::Scope scope = getExecutionScope();
+ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+ return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformShuffle*
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5383f7656a1be..2301379240a1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -123,6 +123,34 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// -----
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirst
+//===----------------------------------------------------------------------===//
+
+func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
+ return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
+ // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
+ return %0 : f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/185818
More information about the Mlir-commits
mailing list