[Mlir-commits] [mlir] [mlir][spirv] Add GroupNonUniformVote instructions (PR #141294)

Darren Wihandi llvmlistbot at llvm.org
Sat May 24 23:15:38 PDT 2025


https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/141294

>From 1361daa941948599b508163d203cb5130b2816c4 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Fri, 23 May 2025 15:02:01 -0600
Subject: [PATCH 1/2] [mlir][spirv] Add GroupNonUniformVote instructions

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  42 ++---
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    | 160 ++++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        |  33 ++++
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     |  73 ++++++++
 mlir/test/Target/SPIRV/non-uniform-ops.mlir   |  18 ++
 5 files changed, 307 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index cd5d201c3d5da..8fd533db83d9a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4464,6 +4464,9 @@ def SPIRV_OC_OpGroupSMax                      : I32EnumAttrCase<"OpGroupSMax", 2
 def SPIRV_OC_OpNoLine                         : I32EnumAttrCase<"OpNoLine", 317>;
 def SPIRV_OC_OpModuleProcessed                : I32EnumAttrCase<"OpModuleProcessed", 330>;
 def SPIRV_OC_OpGroupNonUniformElect           : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
+def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
+def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
+def SPIRV_OC_OpGroupNonUniformAllEqual        : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
 def SPIRV_OC_OpGroupNonUniformBroadcast       : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
 def SPIRV_OC_OpGroupNonUniformBallot          : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPIRV_OC_OpGroupNonUniformBallotBitCount  : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
@@ -4489,8 +4492,8 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
-def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
 def SPIRV_OC_OpUDot                           : I32EnumAttrCase<"OpUDot", 4451>;
 def SPIRV_OC_OpSUDot                          : I32EnumAttrCase<"OpSUDot", 4452>;
@@ -4581,11 +4584,13 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
       SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
       SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
-      SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable,
-      SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd,
-      SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin,
-      SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, SPIRV_OC_OpGroupSMax,
-      SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect,
+      SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
+      SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
+      SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
+      SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
+      SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
+      SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
+      SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
       SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
       SPIRV_OC_OpGroupNonUniformBallotBitCount,
       SPIRV_OC_OpGroupNonUniformBallotFindLSB,
@@ -4599,19 +4604,18 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
-      SPIRV_OC_OpSubgroupBallotKHR,
-      SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
-      SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
-      SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
-      SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
-      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
-      SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
-      SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
-      SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
-      SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
-      SPIRV_OC_OpControlBarrierWaitINTEL, SPIRV_OC_OpGroupIMulKHR,
-      SPIRV_OC_OpGroupFMulKHR
+      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
+      SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
+      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
+      SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
+      SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
+      SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
+      SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
+      SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
+      SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
+      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 3fdaff2470cba..db337f577b37e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1435,4 +1435,164 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
 
 // -----
 
+def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> {
+  let summary = [{
+    Evaluates a predicate for all tangled invocations within the Execution
+    scope, resulting in true if predicate evaluates to true for all tangled
+    invocations within the Execution scope, otherwise the result is false.
+  }];
+
+  let description = [{
+    Result Type must be a Boolean type.
+
+    Execution is the scope defining the scope restricted tangle affected by
+    this command. It must be Subgroup.
+
+    Predicate must be a Boolean type.
+
+    An invocation will not execute a dynamic instance of this instruction
+    (X') until all invocations in its scope restricted tangle have executed
+    all dynamic instances that are program-ordered before X'.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %predicate = ... : i1
+    %0 = spirv.GroupNonUniformAll <Subgroup> %predicate : i1
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformVote]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPIRV_Bool:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope $predicate attr-dict `:` type($result)
+  }];
+}
+
+// -----
+
+def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> {
+  let summary = [{
+    Evaluates a predicate for all tangled invocations within the Execution
+    scope, resulting in true if predicate evaluates to true for any tangled
+    invocations within the Execution scope, otherwise the result is false.
+  }];
+
+  let description = [{
+    Result Type must be a Boolean type.
+
+    Execution is the scope defining the scope restricted tangle affected by
+    this command. It must be Subgroup.
+
+    Predicate must be a Boolean type.
+
+    An invocation will not execute a dynamic instance of this instruction
+    (X') until all invocations in its scope restricted tangle have executed
+    all dynamic instances that are program-ordered before X'.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %predicate = ... : i1
+    %0 = spirv.GroupNonUniformAny <Subgroup> %predicate : i1
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformVote]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPIRV_Bool:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope $predicate attr-dict `:` type($result)
+  }];
+}
+
+// -----
+
+def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> {
+  let summary = [{
+    Evaluates a value for all tangled invocations within the Execution
+    scope. The result is true if Value is equal for all tangled invocations
+    within the Execution scope. Otherwise, the result is false.
+  }];
+
+  let description = [{
+    Result Type must be a Boolean type.
+
+    Execution is the scope defining the scope restricted tangle affected by
+    this command. It must be Subgroup.
+
+    Value must be a scalar or vector of floating-point type, integer type,
+    or Boolean type. The compare operation is based on this type, and if it
+    is a floating-point type, an ordered-and-equal compare is used.
+
+    An invocation will not execute a dynamic instance of this instruction
+    (X') until all invocations in its scope restricted tangle have executed
+    all dynamic instances that are program-ordered before X'.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %0 = spirv.GroupNonUniformAllEqual <Subgroup> %scalar_value : f32, i1
+    %1 = spirv.GroupNonUniformAllEqual <Subgroup> %vector_value : vector<4xf32>, i1
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformVote]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value
+  );
+
+  let results = (outs
+    SPIRV_Bool:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope $value attr-dict `:` type($value) `,` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 461d037134dae..aba876c1c80f4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -327,6 +327,39 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAllOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformAllOp::verify() {
+  if (getExecutionScope() != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Subgroup'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAllOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformAnyOp::verify() {
+  if (getExecutionScope() != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Subgroup'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAllEqualOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformAllEqualOp::verify() {
+  if (getExecutionScope() != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Subgroup'");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group op verification
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 6990f2b3751f5..d7c840dc6a8ef 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -671,3 +671,76 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
   %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
   return %0: f32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAll
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_all
+func.func @group_non_uniform_all(%predicate: i1) -> i1 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
+  %0 = spirv.GroupNonUniformAll <Subgroup> %predicate : i1
+  return %0: i1
+}
+
+// -----
+
+func.func @group_non_uniform_all(%predicate: i1) -> i1 {
+  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  %0 = spirv.GroupNonUniformAll <Device> %predicate : i1
+  return %0: i1
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAny
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_any
+func.func @group_non_uniform_any(%predicate: i1) -> i1 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
+  %0 = spirv.GroupNonUniformAny <Subgroup> %predicate : i1
+  return %0: i1
+}
+
+// -----
+
+func.func @group_non_uniform_any(%predicate: i1) -> i1 {
+  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  %0 = spirv.GroupNonUniformAny <Device> %predicate : i1
+  return %0: i1
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformAllEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_all_equal
+func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : f32, i1
+  %0 = spirv.GroupNonUniformAllEqual <Subgroup> %value : f32, i1
+  return %0: i1
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_all_equal
+func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
+  %0 = spirv.GroupNonUniformAllEqual <Subgroup> %value : vector<4xi32>, i1
+  return %0: i1
+}
+
+
+// -----
+
+func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
+  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  %0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
+  return %0: i1
+}
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 3e78eaf8b03ef..f29ebd86a2e03 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -124,4 +124,22 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.GroupNonUniformShuffleXor <Subgroup> %val, %id : f32, i32
     spirv.ReturnValue %0: f32
   }
+
+  spirv.func @group_non_uniform_all(%pred: i1) -> i1 "None" {
+    // CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
+    %0 = spirv.GroupNonUniformAll <Subgroup> %pred : i1
+    spirv.ReturnValue %0: i1
+  }
+
+  spirv.func @group_non_uniform_any(%pred: i1) -> i1 "None" {
+    // CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
+    %0 = spirv.GroupNonUniformAny <Subgroup> %pred : i1
+    spirv.ReturnValue %0: i1
+  }
+
+  spirv.func @group_non_uniform_all_equal(%val: vector<4xi32>) -> i1 "None" {
+    // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
+    %0 = spirv.GroupNonUniformAllEqual <Subgroup> %val : vector<4xi32>, i1
+    spirv.ReturnValue %0: i1
+  }
 }

>From 573c744863447bfc9b3aa4de42855c95b5be4511 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 25 May 2025 02:15:24 -0400
Subject: [PATCH 2/2] Use existing attribute constraint and remove custom
 verifier functions

---
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    | 19 +++++++++--
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 33 -------------------
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     |  6 ++--
 3 files changed, 19 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index db337f577b37e..7e2ab64afc6d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1435,7 +1435,9 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
 
 // -----
 
-def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> {
+def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
+]> {
   let summary = [{
     Evaluates a predicate for all tangled invocations within the Execution
     scope, resulting in true if predicate evaluates to true for all tangled
@@ -1480,6 +1482,8 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> {
     SPIRV_Bool:$result
   );
 
+  let hasVerifier = 0;
+
   let assemblyFormat = [{
     $execution_scope $predicate attr-dict `:` type($result)
   }];
@@ -1487,7 +1491,9 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", []> {
 
 // -----
 
-def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> {
+def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
+]> {
   let summary = [{
     Evaluates a predicate for all tangled invocations within the Execution
     scope, resulting in true if predicate evaluates to true for any tangled
@@ -1532,6 +1538,8 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> {
     SPIRV_Bool:$result
   );
 
+  let hasVerifier = 0;
+
   let assemblyFormat = [{
     $execution_scope $predicate attr-dict `:` type($result)
   }];
@@ -1539,7 +1547,9 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", []> {
 
 // -----
 
-def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> {
+def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
+]> {
   let summary = [{
     Evaluates a value for all tangled invocations within the Execution
     scope. The result is true if Value is equal for all tangled invocations
@@ -1588,6 +1598,9 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", []> {
     SPIRV_Bool:$result
   );
 
+
+  let hasVerifier = 0;
+
   let assemblyFormat = [{
     $execution_scope $value attr-dict `:` type($value) `,` type($result)
   }];
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index aba876c1c80f4..461d037134dae 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -327,39 +327,6 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformAllOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformAllOp::verify() {
-  if (getExecutionScope() != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Subgroup'");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformAllOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformAnyOp::verify() {
-  if (getExecutionScope() != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Subgroup'");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformAllEqualOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformAllEqualOp::verify() {
-  if (getExecutionScope() != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Subgroup'");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Group op verification
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index d7c840dc6a8ef..5f56de6ad1fa9 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -688,7 +688,7 @@ func.func @group_non_uniform_all(%predicate: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_all(%predicate: i1) -> i1 {
-  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
   %0 = spirv.GroupNonUniformAll <Device> %predicate : i1
   return %0: i1
 }
@@ -709,7 +709,7 @@ func.func @group_non_uniform_any(%predicate: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_any(%predicate: i1) -> i1 {
-  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
   %0 = spirv.GroupNonUniformAny <Device> %predicate : i1
   return %0: i1
 }
@@ -740,7 +740,7 @@ func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 {
 // -----
 
 func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
-  // expected-error @+1 {{execution scope must be 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
   %0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
   return %0: i1
 }



More information about the Mlir-commits mailing list