[Mlir-commits] [mlir] [mlir][spirv] Add spirv.GroupNonUniformBroadcastFirst Op (PR #185818)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 08:45:36 PDT 2026


https://github.com/hankluo6 updated https://github.com/llvm/llvm-project/pull/185818

>From cb0d77ecec0cebbe8986366e470d04fa8a05285f Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 20:01:20 -0800
Subject: [PATCH 1/6] Add lowering for gpu.subgroup_broadcast

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  4 +-
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    | 49 +++++++++++++++++++
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 39 +++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 12 +++++
 4 files changed, 103 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2f189c64300ae..0e489dfc8386d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4587,6 +4587,7 @@ def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
 def SPIRV_OC_OpGroupNonUniformAllEqual        : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
 def SPIRV_OC_OpGroupNonUniformBroadcast       : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
+def SPIRV_OC_OpGroupNonUniformBroadcastFirst  : I32EnumAttrCase<"OpGroupNonUniformBroadcastFirst", 338>;
 def SPIRV_OC_OpGroupNonUniformBallot          : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPIRV_OC_OpGroupNonUniformBallotBitCount  : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
 def SPIRV_OC_OpGroupNonUniformBallotFindLSB   : I32EnumAttrCase<"OpGroupNonUniformBallotFindLSB", 343>;
@@ -4725,7 +4726,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
       SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
       SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
-      SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
+      SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
+      SPIRV_OC_OpGroupNonUniformBallot,
       SPIRV_OC_OpGroupNonUniformBallotBitCount,
       SPIRV_OC_OpGroupNonUniformBallotFindLSB,
       SPIRV_OC_OpGroupNonUniformBallotFindMSB, SPIRV_OC_OpGroupNonUniformShuffle,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 784eb40141b74..a8bdb4256ed1b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -269,6 +269,55 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
 
 // -----
 
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
+  [Pure, AllTypesMatch<["value", "result"]>]> {
+  let summary = [{
+    Result is the Value of the invocation from the active invocations with 
+    the lowest id in the group to all active invocations in the group.
+  }];
+
+  let description = [{
+    Result Type  must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The type of Value must be the same as Result Type.
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
+    %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
+      vector<4xf32>, i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformBallot]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Type:$value
+  );
+
+  let results = (outs
+    SPIRV_Type:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope operands attr-dict `:` type($value)
+  }];
+}
+
+// -----
+
 def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
   let summary = [{
     Result is true only in the active invocation with the lowest id in the
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..d5269511d61a2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,6 +130,18 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.subgroup_broadcast op into a
+/// spirv.GroupNonUniformBroadcast op.
+class GPUSubgroupBroadcastConversion final
+    : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
+public:
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using Base::Base;
@@ -542,6 +554,32 @@ LogicalResult GPURotateConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Subgroup broadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
+    gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  Value result;
+
+  switch (op.getBroadcastType()) {
+  case gpu::BroadcastType::specific_lane:
+    result = spirv::GroupNonUniformBroadcastOp::create(
+        rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
+    break;
+  case gpu::BroadcastType::first_active_lane:
+    result = spirv::GroupNonUniformBroadcastFirstOp::create(
+        rewriter, loc, scope, adaptor.getSrc());
+    break;
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -832,6 +870,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
       GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+      GPUSubgroupBroadcastConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..f1a216504a7e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,6 +135,18 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirstOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
+  spirv::Scope scope = getExecutionScope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformShuffle*
 //===----------------------------------------------------------------------===//

>From dae5bcf304ac7c9e0fbef34186430c93efe32f8a Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Mar 2026 22:54:21 -0800
Subject: [PATCH 2/6] Add test

---
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    |  4 +--
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 28 +++++++++++++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index a8bdb4256ed1b..1ded87fc2c090 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -290,8 +290,8 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
     %scalar_value = ... : f32
     %vector_value = ... : vector<4xf32>
     %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
-    %1 = spirv.GroupNonUniformBroadcast <Workgroup> %vector_value, %id :
-      vector<4xf32>, i32
+    %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
+      vector<4xf32>, i32 i32
     ```
   }];
 
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5383f7656a1be..2301379240a1a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -123,6 +123,34 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcastFirst
+//===----------------------------------------------------------------------===//
+
+func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
+  %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
+  return %0: f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+  %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+  return %0 : f32
+}
+
+// -----
+
+func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
+  return %0 : f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformElect
 //===----------------------------------------------------------------------===//

>From 19c8548f38e79054f2315fa4a9be959dc849d721 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Wed, 11 Mar 2026 23:06:31 -0700
Subject: [PATCH 3/6] Remove conversion

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 39 -------------------
 1 file changed, 39 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d5269511d61a2..c33a903d03393 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -130,18 +130,6 @@ class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Pattern to convert a gpu.subgroup_broadcast op into a
-/// spirv.GroupNonUniformBroadcast op.
-class GPUSubgroupBroadcastConversion final
-    : public OpConversionPattern<gpu::SubgroupBroadcastOp> {
-public:
-  using Base::Base;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using Base::Base;
@@ -554,32 +542,6 @@ LogicalResult GPURotateConversion::matchAndRewrite(
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Subgroup broadcast
-//===----------------------------------------------------------------------===//
-
-LogicalResult GPUSubgroupBroadcastConversion::matchAndRewrite(
-    gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
-  Value result;
-
-  switch (op.getBroadcastType()) {
-  case gpu::BroadcastType::specific_lane:
-    result = spirv::GroupNonUniformBroadcastOp::create(
-        rewriter, loc, scope, adaptor.getSrc(), adaptor.getLane());
-    break;
-  case gpu::BroadcastType::first_active_lane:
-    result = spirv::GroupNonUniformBroadcastFirstOp::create(
-        rewriter, loc, scope, adaptor.getSrc());
-    break;
-  }
-
-  rewriter.replaceOp(op, result);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -870,7 +832,6 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
       GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
-      GPUSubgroupBroadcastConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

>From d11bec53bc20f072e308e54bb7fe6788534d5b1e Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Thu, 12 Mar 2026 23:45:23 -0700
Subject: [PATCH 4/6] Fix ops and tests

---
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    | 20 ++++++++-------
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 12 ---------
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 25 +++++++++++++------
 mlir/test/Target/SPIRV/non-uniform-ops.mlir   |  7 ++++++
 4 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 1ded87fc2c090..396629d65f8bc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -270,17 +270,18 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
 // -----
 
 def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
-  [Pure, AllTypesMatch<["value", "result"]>]> {
+  [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
   let summary = [{
-    Result is the Value of the invocation from the active invocations with 
-    the lowest id in the group to all active invocations in the group.
+    Result is the Value of the invocation from the active invocations with
+    the lowest id within the Execution scope to all active invocations
+    within the Execution scope.
   }];
 
   let description = [{
-    Result Type  must be a scalar or vector of floating-point type, integer
+    Result Type must be a scalar or vector of floating-point type, integer
     type, or Boolean type.
 
-    Execution must be Workgroup or Subgroup Scope.
+    Execution must be Subgroup Scope.
 
     The type of Value must be the same as Result Type.
 
@@ -290,8 +291,7 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
     %scalar_value = ... : f32
     %vector_value = ... : vector<4xf32>
     %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %scalar_value : f32
-    %1 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %vector_value :
-      vector<4xf32>, i32 i32
+    %1 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %vector_value : vector<4xf32>
     ```
   }];
 
@@ -304,13 +304,15 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
 
   let arguments = (ins
     SPIRV_ScopeAttr:$execution_scope,
-    SPIRV_Type:$value
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value
   );
 
   let results = (outs
-    SPIRV_Type:$result
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
   );
 
+  let hasVerifier = 0;
+
   let assemblyFormat = [{
     $execution_scope operands attr-dict `:` type($value)
   }];
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index f1a216504a7e4..a1bb7f89e9183 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -135,18 +135,6 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBroadcastFirstOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBroadcastFirstOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformShuffle*
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 2301379240a1a..9b117d00d7bc5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -127,12 +127,6 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
 // spirv.GroupNonUniformBroadcastFirst
 //===----------------------------------------------------------------------===//
 
-func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
-  // CHECK: spirv.GroupNonUniformBroadcastFirst <Workgroup> %{{.*}} : f32
-  %0 = spirv.GroupNonUniformBroadcastFirst <Workgroup> %value : f32
-  return %0: f32
-}
-
 // -----
 
 func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
@@ -143,14 +137,31 @@ func.func @group_non_uniform_broadcast_first_scalar(%value: f32) -> f32 {
 
 // -----
 
+func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : vector<4xf32>
+  %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
   %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
   return %0 : f32
 }
 
 // -----
 
+
+func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
+  %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : !spirv.array<3 x i32>
+  return %0 : !spirv.array<3 x i32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformElect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 6975836d3ddee..1dbc4a43395a4 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -21,6 +21,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
     spirv.ReturnValue %0: f32
   }
 
+  // CHECK-LABEL: @group_non_uniform_broadcast_first
+  spirv.func @group_non_uniform_broadcast_first(%value: f32) -> f32 "None" {
+    // CHECK: spirv.GroupNonUniformBroadcastFirst <Subgroup> %{{.*}} : f32
+    %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : f32
+    spirv.ReturnValue %0: f32
+  }
+
   // CHECK-LABEL: @group_non_uniform_elect
   spirv.func @group_non_uniform_elect() -> i1 "None" {
     // CHECK: %{{.+}} = spirv.GroupNonUniformElect <Workgroup> : i1

>From 94e77b1b766dda7fb70952b4c248956b8d25e71e Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Tue, 17 Mar 2026 00:27:06 -0700
Subject: [PATCH 5/6] Add mismatch test

---
 mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 9b117d00d7bc5..2ca7601a1c9ad 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -154,7 +154,7 @@ func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32
 // -----
 
 
-func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+func.func @group_non_uniform_broadcast_first_negative_type(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
   // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
   %0 = spirv.GroupNonUniformBroadcastFirst <Subgroup> %value : !spirv.array<3 x i32>
   return %0 : !spirv.array<3 x i32>
@@ -162,6 +162,14 @@ func.func @group_non_uniform_broadcast_negative_type(%value: !spirv.array<3 x i3
 
 // -----
 
+func.func @group_non_uniform_broadcast_first_negative_type_mismatch(%value: f32) -> i32 {
+  // expected-error @+1 {{'spirv.GroupNonUniformBroadcastFirst' op failed to verify that all of {value, result} have same type}}
+  %0 = "spirv.GroupNonUniformBroadcastFirst"(%value) {execution_scope = #spirv.scope<Subgroup>} : (f32) -> i32
+  return %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformElect
 //===----------------------------------------------------------------------===//

>From f33ae21563b9adf0d509d5165890f3cc4aca2ea7 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Tue, 17 Mar 2026 00:28:06 -0700
Subject: [PATCH 6/6] Refine summary

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 396629d65f8bc..7ede319f85a5b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -272,12 +272,15 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
 def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
   [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
   let summary = [{
+    Broadcast the value from the active invocation with the lowest id in
+    the subgroup.
+  }];
+
+  let description = [{
     Result is the Value of the invocation from the active invocations with
     the lowest id within the Execution scope to all active invocations
     within the Execution scope.
-  }];
 
-  let description = [{
     Result Type must be a scalar or vector of floating-point type, integer
     type, or Boolean type.
 



More information about the Mlir-commits mailing list