[Mlir-commits] [mlir] [mlir][spirv] Enforce execution scope for group operations in ODS (PR #196644)

Igor Wodiany llvmlistbot at llvm.org
Fri May 8 14:02:22 PDT 2026


https://github.com/IgWod created https://github.com/llvm/llvm-project/pull/196644

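This moves `SPIRV_ExecutionScopeAttrIs` into SPIRVBase.td, generalizes it to accept a list of allowed scopes, and uses it for both group and non-uniform group operations, replacing the hand-written execution scope checks in the C++ verifiers.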

>From 704c2d6a418e875d377b562ee65a34b209e08e09 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at amd.com>
Date: Tue, 14 Apr 2026 19:42:33 +0100
Subject: [PATCH] [mlir][spirv] Enforce execution scope for group operations in
 ODS

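This moves `SPIRV_ExecutionScopeAttrIs` into SPIRVBase.td, generalizes
it to accept a list of allowed scopes, and uses it for both group and
non-uniform group operations, replacing the hand-written execution
scope checks in the C++ verifiers.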
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  11 ++
 .../mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td    |  53 +++++++--
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    |  95 ++++++++++------
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 106 ------------------
 mlir/test/Dialect/SPIRV/IR/group-ops.mlir     |   3 +-
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     |  34 +++---
 6 files changed, 132 insertions(+), 170 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3bae0fc5a1acc..742f08137f3be 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -5055,4 +5055,24 @@ def SPIRV_TosaExtRoundingModeAttr : SPIRV_I32EnumAttr<
       I32EnumAttrCase<"DoubleRound", 3>,
   ]>;
 
+//===----------------------------------------------------------------------===//
+// SPIR-V Common Constraints.
+//===----------------------------------------------------------------------===//
+
+// Op trait requiring the `ScopeAttr` property named `operand` to hold one of
+// the `::mlir::spirv::Scope` enumerants listed in `values`, e.g.
+// SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>.
+// The description "<operand> must be '<v0>' or '<v1>' ..." is embedded in the
+// error emitted by the generated verifier when the check fails.
+class SPIRV_ExecutionScopeAttrIs<string operand, list<string> values> : PredOpTrait<
+  operand # " must be '" # !interleave(values, "' or '") # "'",
+  CPred<"::llvm::is_contained({::mlir::spirv::Scope::" # !interleave(values, ", ::mlir::spirv::Scope::") #
+        "}, ::llvm::cast<::mlir::spirv::ScopeAttr>(getProperties()." # operand #
+        ").getValue())">
+> {
+  // An empty list would expand to an ill-formed C++ initializer list.
+  assert !not(!empty(values)),
+    "SPIRV_ExecutionScopeAttrIs requires at least one allowed scope";
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
index 400e37432f388..047686f781bcb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
@@ -18,7 +18,8 @@
 // -----
 
 def SPIRV_GroupFMulKHROp : SPIRV_KhrVendorOp<"GroupFMul", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A floating-point multiplication group operation specified for all values of
     'X' specified by invocations in the group.
@@ -67,13 +68,16 @@ def SPIRV_GroupFMulKHROp : SPIRV_KhrVendorOp<"GroupFMul", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast",
                               [Pure,
-                               AllTypesMatch<["value", "result"]>]> {
+                               AllTypesMatch<["value", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     Broadcast the Value of the invocation identified by the local id LocalId
     to the result of all invocations in the group.
@@ -135,7 +139,8 @@ def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast",
 // -----
 
 def SPIRV_GroupFAddOp : SPIRV_Op<"GroupFAdd", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A floating-point add group operation specified for all values of X
     specified by invocations in the group.
@@ -183,12 +188,15 @@ def SPIRV_GroupFAddOp : SPIRV_Op<"GroupFAdd", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupFMaxOp : SPIRV_Op<"GroupFMax", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A floating-point maximum group operation specified for all values of X
     specified by invocations in the group.
@@ -236,12 +244,15 @@ def SPIRV_GroupFMaxOp : SPIRV_Op<"GroupFMax", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupFMinOp : SPIRV_Op<"GroupFMin", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A floating-point minimum group operation specified for all values of X
     specified by invocations in the group.
@@ -289,12 +300,15 @@ def SPIRV_GroupFMinOp : SPIRV_Op<"GroupFMin", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupIAddOp : SPIRV_Op<"GroupIAdd", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     An integer add group operation specified for all values of X specified
     by invocations in the group.
@@ -342,12 +356,15 @@ def SPIRV_GroupIAddOp : SPIRV_Op<"GroupIAdd", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupIMulKHROp : SPIRV_KhrVendorOp<"GroupIMul", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     An integer multiplication group operation specified for all values of 'X'
     specified by invocations in the group.
@@ -395,12 +412,15 @@ def SPIRV_GroupIMulKHROp : SPIRV_KhrVendorOp<"GroupIMul", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupSMaxOp : SPIRV_Op<"GroupSMax", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A signed integer maximum group operation specified for all values of X
     specified by invocations in the group.
@@ -449,12 +469,15 @@ def SPIRV_GroupSMaxOp : SPIRV_Op<"GroupSMax", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupSMinOp : SPIRV_Op<"GroupSMin", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     A signed integer minimum group operation specified for all values of X
     specified by invocations in the group.
@@ -503,12 +526,15 @@ def SPIRV_GroupSMinOp : SPIRV_Op<"GroupSMin", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupUMaxOp : SPIRV_Op<"GroupUMax", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     An unsigned integer maximum group operation specified for all values of
     X specified by invocations in the group.
@@ -556,12 +582,15 @@ def SPIRV_GroupUMaxOp : SPIRV_Op<"GroupUMax", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_GroupUMinOp : SPIRV_Op<"GroupUMin", [Pure,
-                               AllTypesMatch<["x", "result"]>]> {
+                               AllTypesMatch<["x", "result"]>,
+                               SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
   let summary = [{
     An unsigned integer minimum group operation specified for all values of
     X specified by invocations in the group.
@@ -610,6 +639,8 @@ def SPIRV_GroupUMinOp : SPIRV_Op<"GroupUMin", [Pure,
   let assemblyFormat = [{
     $execution_scope $group_operation operands attr-dict `:` type($x)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 7ede319f85a5b..1a0ab0ff98a8a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -14,17 +14,12 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
 #define MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
 
-class SPIRV_AttrIs<string operand, string type, string value> : PredOpTrait<
-  operand # " must be " # type # " of value " # value,
-  CPred<"::llvm::cast<::mlir::spirv::" # type # "Attr>(getProperties()." # operand # ").getValue() == ::mlir::spirv::" # type # "::" # value>
-  >;
-
-class SPIRV_ExecutionScopeAttrIs<string operand, string value> : SPIRV_AttrIs<operand, "Scope", value>;
-
 // -----
 
 class SPIRV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
-      list<Trait> traits = []> : SPIRV_Op<mnemonic, traits> {
+      list<Trait> traits = []> : SPIRV_Op<mnemonic, !listconcat([
+        SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>
+      ], traits)> {
 
   let arguments = (ins
     SPIRV_ScopeAttr:$execution_scope,
@@ -46,7 +41,9 @@ class SPIRV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
 
 // -----
 
-def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot", []> {
+def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is a bitfield value combining the Predicate value from all
     invocations in the group that execute the same dynamic instance of this
@@ -94,11 +91,15 @@ def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot", []> {
   let assemblyFormat = [{
     $execution_scope $predicate attr-dict `:` type($result)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
-def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLSB", []> {
+def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLSB", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Find the least significant bit set to 1 in Value, considering only the
     bits in Value required to represent all bits of the group's invocations.
@@ -150,11 +151,15 @@ def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLS
   let assemblyFormat = [{
     $execution_scope $value attr-dict `:` type($value) `,` type($result)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
-def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMSB", []> {
+def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMSB", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Find the most significant bit set to 1 in Value, considering only the
     bits in Value required to represent all bits of the group's invocations.
@@ -206,12 +211,16 @@ def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMS
   let assemblyFormat = [{
     $execution_scope $value attr-dict `:` type($value) `,` type($result)
   }];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
-def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
-  [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is the Value of the invocation identified by the id Id to all
     active invocations in the group.
@@ -269,8 +278,10 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
 
 // -----
 
-def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
-  [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
   let summary = [{
     Broadcast the value from the active invocation with the lowest id in
     the subgroup.
@@ -323,7 +334,9 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
 
 // -----
 
-def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
+def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is true only in the active invocation with the lowest id in the
     group, otherwise result is false.
@@ -357,6 +370,8 @@ def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
   );
 
   let assemblyFormat = "$execution_scope attr-dict `:` type($result)";
+
+  let hasVerifier = 0;
 }
 
 // -----
@@ -739,8 +754,10 @@ def SPIRV_GroupNonUniformSMinOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
 
 // -----
 
-def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle",
-    [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is the Value of the invocation identified by the id Id.
   }];
@@ -791,8 +808,10 @@ def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle",
 
 // -----
 
-def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown",
-    [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is the Value of the invocation identified by the current
     invocation’s id within the group + Delta.
@@ -846,8 +865,10 @@ def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown",
 
 // -----
 
-def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp",
-    [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is the Value of the invocation identified by the current
     invocation’s id within the group - Delta.
@@ -900,8 +921,10 @@ def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp",
 
 // -----
 
-def SPIRV_GroupNonUniformShuffleXorOp : SPIRV_Op<"GroupNonUniformShuffleXor",
-    [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleXorOp : SPIRV_Op<"GroupNonUniformShuffleXor", [
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Result is the Value of the invocation identified by the current
     invocation’s id within the group xor’ed with Mask.
@@ -1351,8 +1374,8 @@ def SPIRV_GroupNonUniformLogicalXorOp :
 // -----
 
 def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCount", [
-  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">,
-]> {
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
   let summary = [{
     Result is the number of bits that are set to 1 in Value, considering
     only the bits in Value required to represent all bits of the scope
@@ -1416,7 +1439,9 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
 // -----
 
 def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
-  Pure, AllTypesMatch<["value", "result"]>]> {
+  Pure, AllTypesMatch<["value", "result"]>,
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
   let summary = [{
     Rotate values across invocations within a subgroup.
   }];
@@ -1490,8 +1515,8 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
 // -----
 
 def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
-  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
   let summary = [{
     Evaluates a predicate for all tangled invocations within the Execution
     scope, resulting in true if predicate evaluates to true for all tangled
@@ -1546,8 +1571,8 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
 // -----
 
 def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
-  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
   let summary = [{
     Evaluates a predicate for all tangled invocations within the Execution
     scope, resulting in true if predicate evaluates to true for any tangled
@@ -1602,8 +1627,8 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
 // -----
 
 def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
-  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
   let summary = [{
     Evaluates a value for all tangled invocations within the Execution
     scope. The result is true if Value is equal for all tangled invocations
@@ -1663,8 +1688,8 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
 // -----
 
 def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
-  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>
-]> {
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>, AllTypesMatch<["value", "result"]>]> {
+
   let summary = [{
     Swap the Value of the invocation within the quad with another invocation
     in the quad using Direction.
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..fe6f00e9e5bca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -22,15 +22,6 @@ namespace mlir::spirv {
 
 template <typename OpTy>
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
-  spirv::Scope scope =
-      groupOp
-          ->getAttrOfType<spirv::ScopeAttr>(
-              OpTy::getExecutionScopeAttrName(groupOp->getName()))
-          .getValue();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return groupOp->emitOpError(
-        "execution scope must be 'Workgroup' or 'Subgroup'");
-
   GroupOperation operation =
       groupOp
           ->getAttrOfType<GroupOperationAttr>(
@@ -61,10 +52,6 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupBroadcastOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
   if (auto localIdTy = dyn_cast<VectorType>(getLocalid().getType()))
     if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
       return emitOpError("localid is a vector and can be with only "
@@ -74,51 +61,11 @@ LogicalResult GroupBroadcastOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotFindLSBOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotFindLSBOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotFindLSBOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotFindMSBOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformBroadcast
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBroadcastOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
   // SPIR-V spec: "Before version 1.5, Id must come from a
   // constant instruction.
   auto targetEnv = spirv::getDefaultTargetEnv(getContext());
@@ -141,10 +88,6 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
 
 template <typename OpTy>
 static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
-  spirv::Scope scope = op.getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
   if (op.getOperands().back().getType().isSignedInteger())
     return op.emitOpError("second operand must be a singless/unsigned integer");
 
@@ -164,18 +107,6 @@ LogicalResult GroupNonUniformShuffleXorOp::verify() {
   return verifyGroupNonUniformShuffleOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformElectOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformElectOp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFAddOp
 //===----------------------------------------------------------------------===//
@@ -309,10 +240,6 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformRotateKHROp::verify() {
-  spirv::Scope scope = getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
   if (Value clusterSizeVal = getClusterSize()) {
     mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
     int32_t clusterSize = 0;
@@ -327,37 +254,4 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// Group op verification
-//===----------------------------------------------------------------------===//
-
-template <typename Op>
-static LogicalResult verifyGroupOp(Op op) {
-  spirv::Scope scope = op.getExecutionScope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
-  return success();
-}
-
-LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
-
 } // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index e69a07ff885f0..1034b9f02ef52 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -41,7 +41,7 @@ func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>
 // -----
 
 func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupBroadcast <Device> %value, %localid : f32, vector<3xi32>
   return %0: f32
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index fb18d69f58241..abc6964026646 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -13,7 +13,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
 // -----
 
 func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformBallot <Device> %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
@@ -41,7 +41,7 @@ func.func @group_non_uniform_ballot_find_lsb(%value : vector<4xi32>) -> i32 {
 // -----
 
 func.func @group_non_uniform_ballot_find_lsb(%value : vector<4xi32>) -> i32 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformBallotFindLSB <Device> %value : vector<4xi32>, i32
   return %0: i32
 }
@@ -69,7 +69,7 @@ func.func @group_non_uniform_ballot_find_msb(%value : vector<4xi32>) -> i32 {
 // -----
 
 func.func @group_non_uniform_ballot_find_msb(%value : vector<4xi32>) -> i32 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformBallotFindMSB <Device> %value : vector<4xi32>, i32
   return %0: i32
 }
@@ -108,7 +108,7 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4
 
 func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
   %one = spirv.Constant 1 : i32
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformBroadcast <Device> %value, %one : f32, i32
   return %0: f32
 }
@@ -146,7 +146,7 @@ func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> ve
 // -----
 
 func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
   return %0 : f32
 }
@@ -184,7 +184,7 @@ func.func @group_non_uniform_elect() -> i1 {
 // -----
 
 func.func @group_non_uniform_elect() -> i1 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformElect <CrossDevice> : i1
   return %0: i1
 }
@@ -295,7 +295,7 @@ func.func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vecto
 // -----
 
 func.func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformIAdd <Device> <Reduce> %val : i32 -> i32
   return %0: i32
 }
@@ -395,7 +395,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
 // -----
 
 func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformShuffle <Device> %val, %id : vector<2xf32>, i32
   return %0: vector<2xf32>
 }
@@ -431,7 +431,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
 // -----
 
 func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformShuffleXor <Device> %val, %id : vector<2xf32>, i32
   return %0: vector<2xf32>
 }
@@ -467,7 +467,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
 // -----
 
 func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformShuffleUp <Device> %val, %id : vector<2xf32>, i32
   return %0: vector<2xf32>
 }
@@ -503,7 +503,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
 // -----
 
 func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformShuffleDown <Device> %val, %id : vector<2xf32>, i32
   return %0: vector<2xf32>
 }
@@ -695,7 +695,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
 
 func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
   %four = spirv.Constant 4 : i32
-  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spirv.GroupNonUniformRotateKHR <Device> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
   return %0: f32
 }
@@ -751,7 +751,7 @@ func.func @group_non_uniform_all(%predicate: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_all(%predicate: i1) -> i1 {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformAll <Device> %predicate : i1
   return %0: i1
 }
@@ -772,7 +772,7 @@ func.func @group_non_uniform_any(%predicate: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_any(%predicate: i1) -> i1 {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformAny <Device> %predicate : i1
   return %0: i1
 }
@@ -803,7 +803,7 @@ func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 {
 // -----
 
 func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
   return %0: i1
 }
@@ -837,7 +837,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
 // -----
 
 func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : vector<4xf32>
   return %0: vector<4xf32>
 }
@@ -874,7 +874,7 @@ func.func @group_non_uniform_ballot_bit_count(%value: vector<4xi32>) -> i32 {
 // -----
 
 func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  // expected-error @+1 {{execution_scope must be 'Subgroup'}}
   %0 = spirv.GroupNonUniformBallotBitCount <Workgroup> <Reduce> %value : vector<4xi32> -> i32
   return %0: i32
 }



More information about the Mlir-commits mailing list