[Mlir-commits] [mlir] [mlir][spirv] Add spirv.GroupNonUniformBroadcastFirst Op (PR #185818)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 08:45:36 PDT 2026
https://github.com/hankluo6 updated https://github.com/llvm/llvm-project/pull/185818
>From cb0d77ecec0cebbe8986366e470d04fa8a05285f Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 20:01:20 -0800
Subject: [PATCH 1/6] Add lowering for gpu.subgroup_broadcast
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 4 +-
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 49 +++++++++++++++++++
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 39 +++++++++++++++
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 12 +++++
4 files changed, 103 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2f189c64300ae..0e489dfc8386d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4587,6 +4587,7 @@ def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
+def SPIRV_OC_OpGroupNonUniformBroadcastFirst : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>;
def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
def SPIRV_OC_OpGroupNonUniformBallotFindLSB : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
@@ -4725,7 +4726,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
- SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
+ SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
+ SPIRV_OC_OpGroupNonUniformBallot,
SPIRV_OC_OpGroupNonUniformBallotBitCount,
SPIRV_OC_OpGroupNonUniformBallotFindLSB,
SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 784eb40141b74..a8bdb4256ed1b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -269,6 +269,55 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
// -----
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
+ [Pure, AllTypesMatch<["value", "result"]>]> {
+ let summary = [{
+ Result is the Value of the invocation from the active invocations with
+ the lowest id in the group to all active invocations in the group.
+ }];
+
+ let description = [{
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution must be Workgroup or Subgroup Scope.
+
+ The type of Value must be the same as Result Type.
+
+ #### Example:
+
+ ```mlir
+ %scalar_value = ... : f32
+ %vector_value = ... : vector<4xf32>
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
+ %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
+ vector<4xf32>, i32
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_3>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_GroupNonUniformBallot]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScopeAttr:$execution_scope,
+ SPIRV_Type:$value
+ );
+
+ let results = (outs
+ SPIRV_Type:$result
+ );
+
+ let assemblyFormat = [{
+ $execution_scope operands attr-dict `:` type($value)
+ }];
+}
+
+// -----
+
def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
let summary = [{
Result is true only in the active invocation with the lowest id in the
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..d5269511d61a2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,6 +130,18 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Pattern to convert a gpu.subgroup_broadcast op into a
+/// spirv.GroupNonUniformBroadcast op.
+class GPUSubgroupBroadcastConversion final
+ : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
+public:
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using Base::Base;
@@ -542,6 +554,32 @@ LogicalResult GPURotateConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Subgroup broadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
+ gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value result;
+
+ switch (op.getBroadcastType()) {
+ case gpu::BroadcastType::specific_lane:
+ result = spirv::GroupNonUniformBroadcastOp::create(
+ rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
+ break;
+ case gpu::BroadcastType::first_active_lane:
+ result = spirv::GroupNonUniformBroadcastFirstOp::create(
+ rewriter, loc, scope, adaptor.getSrc());
+ break;
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -832,6 +870,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+ GPUSubgroupBroadcastConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..f1a216504a7e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,6 +135,18 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirstOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
+ spirv::Scope scope = getExecutionScope();
+ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+ return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformShuffle*
//===----------------------------------------------------------------------===//
>From dae5bcf304ac7c9e0fbef34186430c93efe32f8a Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 22:54:21 -0800
Subject: [PATCH 2/6] Add test
---
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 4 +--
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 28 +++++++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index a8bdb4256ed1b..1ded87fc2c090 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -290,8 +290,8 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
%scalar_value = ... : f32
%vector_value = ... : vector<4xf32>
%0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
- %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
- vector<4xf32>, i32
+ %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
+ vector<4xf32>, i32 i32
```
}];
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5383f7656a1be..2301379240a1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -123,6 +123,34 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// -----
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirst
+//===----------------------------------------------------------------------===//
+
+func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
+ return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
+ // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
+ return %0 : f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
>From 19c8548f38e79054f2315fa4a9be959dc849d721 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Wed, 11 Mar 2026 23:06:31 -0700
Subject: [PATCH 3/6] Remove conversion
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 39 -------------------
1 file changed, 39 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d5269511d61a2..c33a903d03393 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,18 +130,6 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Pattern to convert a gpu.subgroup_broadcast op into a
-/// spirv.GroupNonUniformBroadcast op.
-class GPUSubgroupBroadcastConversion final
- : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
-public:
- using Base::Base;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using Base::Base;
@@ -554,32 +542,6 @@ LogicalResult GPURotateConversion::matchAndRewrite(
return success();
}
-//===----------------------------------------------------------------------===//
-// Subgroup broadcast
-//===----------------------------------------------------------------------===//
-
-LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
- gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Location loc = op.getLoc();
- auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
- Value result;
-
- switch (op.getBroadcastType()) {
- case gpu::BroadcastType::specific_lane:
- result = spirv::GroupNonUniformBroadcastOp::create(
- rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
- break;
- case gpu::BroadcastType::first_active_lane:
- result = spirv::GroupNonUniformBroadcastFirstOp::create(
- rewriter, loc, scope, adaptor.getSrc());
- break;
- }
-
- rewriter.replaceOp(op, result);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -870,7 +832,6 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
- GPUSubgroupBroadcastConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
>From d11bec53bc20f072e308e54bb7fe6788534d5b1e Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Thu, 12 Mar 2026 23:45:23 -0700
Subject: [PATCH 4/6] Fix ops and tests
---
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 20 ++++++++-------
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 12 ---------
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 25 +++++++++++++------
mlir/test/Target/SPIRV/non-uniform-ops.mlir | 7 ++++++
4 files changed, 36 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 1ded87fc2c090..396629d65f8bc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -270,17 +270,18 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
// -----
def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+ [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
let summary = [{
- Result is the Value of the invocation from the active invocations with
- the lowest id in the group to all active invocations in the group.
+ Result is the Value of the invocation from the active invocations with
+ the lowest id within the Execution scope to all active invocations
+ within the Execution scope.
}];
let description = [{
- Result Type must be a scalar or vector of floating-point type, integer
+ Result Type must be a scalar or vector of floating-point type, integer
type, or Boolean type.
- Execution must be Workgroup or Subgroup Scope.
+ Execution must be Subgroup Scope.
The type of Value must be the same as Result Type.
@@ -290,8 +291,7 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
%scalar_value = ... : f32
%vector_value = ... : vector<4xf32>
%0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
- %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
- vector<4xf32>, i32 i32
+ %1 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %vector_value : vector<4xf32>
```
}];
@@ -304,13 +304,15 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
let arguments = (ins
SPIRV_ScopeAttr:$execution_scope,
- SPIRV_Type:$value
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value
);
let results = (outs
- SPIRV_Type:$result
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
);
+ let hasVerifier = 0;
+
let assemblyFormat = [{
$execution_scope operands attr-dict `:` type($value)
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index f1a216504a7e4..a1bb7f89e9183 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,18 +135,6 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBroadcastFirstOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformShuffle*
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 2301379240a1a..9b117d00d7bc5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -127,12 +127,6 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// spirv.GroupNonUniformBroadcastFirst
//===----------------------------------------------------------------------===//
-func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
- // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
- %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
- return %0: f32
-}
-
// -----
func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
@@ -143,14 +137,31 @@ func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
// -----
+func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : vector<4xf32>
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
%0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
return %0 : f32
}
// -----
+
+func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : !spirv.array<3 x i32>
+ return %0 : !spirv.array<3 x i32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 6975836d3ddee..1dbc4a43395a4 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -21,6 +21,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
spirv.ReturnValue %0: f32
}
+ // CHECK-LABEL: @group_non_uniform_broadcast_first
+ spirv.func @group_non_uniform_broadcast_first(%value: f32) -> f32 "None" {
+ // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+ %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+ spirv.ReturnValue %0: f32
+ }
+
// CHECK-LABEL: @group_non_uniform_elect
spirv.func @group_non_uniform_elect() -> i1 "None" {
// CHECK: %{{.+}} = spirv.GroupNonUniformElect <Workgroup> : i1
>From 94e77b1b766dda7fb70952b4c248956b8d25e71e Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Tue, 17 Mar 2026 00:27:06 -0700
Subject: [PATCH 5/6] Add mismatch test
---
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 9b117d00d7bc5..2ca7601a1c9ad 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -154,7 +154,7 @@ func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32
// -----
-func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+func.func @group_non_uniform_broadcast_first_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
%0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : !spirv.array<3 x i32>
return %0 : !spirv.array<3 x i32>
@@ -162,6 +162,14 @@ func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i3
// -----
+func.func @group_non_uniform_broadcast_first_negative_type_mismatch(%value: f32) -> i32 {
+ // expected-error @+1 {{'spirv.GroupNonUniformBroadcastFirst' op failed to verify that all of {value, result} have same type}}
+ %0 = "spirv.GroupNonUniformBroadcastFirst"(%value) {execution_scope = #spirv.scope<Subgroup>} : (f32) -> i32
+ return %0 : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElect
//===----------------------------------------------------------------------===//
>From f33ae21563b9adf0d509d5165890f3cc4aca2ea7 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Tue, 17 Mar 2026 00:28:06 -0700
Subject: [PATCH 6/6] Refine summary
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 396629d65f8bc..7ede319f85a5b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -272,12 +272,15 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
[Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
let summary = [{
+ Broadcast the value from the active invocation with the lowest id in
+ the subgroup.
+ }];
+
+ let description = [{
Result is the Value of the invocation from the active invocations with
the lowest id within the Execution scope to all active invocations
within the Execution scope.
- }];
- let description = [{
Result Type must be a scalar or vector of floating-point type, integer
type, or Boolean type.
More information about the Mlir-commits
mailing list