[Mlir-commits] [mlir] [mlir][spirv] Enforce execution scope for group operations in ODS (PR #196644)
Igor Wodiany
llvmlistbot at llvm.org
Fri May 8 14:02:22 PDT 2026
https://github.com/IgWod created https://github.com/llvm/llvm-project/pull/196644
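This adds a new ODS constraint class, `SPIRV_ExecutionScopeAttrIs`, shared between group and non-uniform group operations. It accepts a list of allowed execution scopes and replaces the hand-written execution-scope checks previously done in the C++ verifiers.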
From 704c2d6a418e875d377b562ee65a34b209e08e09 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at amd.com>
Date: Tue, 14 Apr 2026 19:42:33 +0100
Subject: [PATCH] [mlir][spirv] Enforce execution scope for group operations in ODS
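This adds a new ODS constraint class, `SPIRV_ExecutionScopeAttrIs`, shared
between group and non-uniform group operations. It accepts a list of allowed
execution scopes and replaces the hand-written execution-scope checks
previously done in the C++ verifiers.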
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 ++
.../mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td | 53 +++++++--
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 95 ++++++++++------
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 106 ------------------
mlir/test/Dialect/SPIRV/IR/group-ops.mlir | 3 +-
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 34 +++---
6 files changed, 132 insertions(+), 170 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3bae0fc5a1acc..742f08137f3be 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -5055,4 +5055,24 @@ def SPIRV_TosaExtRoundingModeAttr : SPIRV_I32EnumAttr<
  I32EnumAttrCase<"DoubleRound", 3>,
]>;
+//===----------------------------------------------------------------------===//
+// SPIR-V Common Constraints.
+//===----------------------------------------------------------------------===//
+
+// Op trait verifying that the `operand` property (a spirv::ScopeAttr) holds
+// one of the spirv::Scope enumerants named in `values`. On failure the
+// verifier diagnostic contains "<operand> must be '<v0>' or '<v1>' ...".
+class SPIRV_ExecutionScopeAttrIs<string operand, list<string> values>
+    : PredOpTrait<
+        operand # " must be '" # !interleave(values, "' or '") # "'",
+        CPred<"::llvm::is_contained({::mlir::spirv::Scope::" #
+              !interleave(values, ", ::mlir::spirv::Scope::") #
+              "}, ::llvm::cast<::mlir::spirv::ScopeAttr>(getProperties()." #
+              operand # ").getValue())">> {
+  // An empty list would expand to uncompilable C++ (`Scope::}`) and a
+  // meaningless "must be ''" message, so reject it at TableGen time.
+  assert !not(!empty(values)),
+         "SPIRV_ExecutionScopeAttrIs requires at least one scope value";
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
index 400e37432f388..047686f781bcb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
@@ -18,7 +18,8 @@
// -----
def SPIRV_GroupFMulKHROp : SPIRV_KhrVendorOp<"GroupFMul", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A floating-point multiplication group operation specified for all values of
'X' specified by invocations in the group.
@@ -67,13 +68,16 @@ def SPIRV_GroupFMulKHROp : SPIRV_KhrVendorOp<"GroupFMul", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast",
[Pure,
- AllTypesMatch<["value", "result"]>]> {
+ AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
Broadcast the Value of the invocation identified by the local id LocalId
to the result of all invocations in the group.
@@ -135,7 +139,8 @@ def SPIRV_GroupBroadcastOp : SPIRV_Op<"GroupBroadcast",
// -----
def SPIRV_GroupFAddOp : SPIRV_Op<"GroupFAdd", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A floating-point add group operation specified for all values of X
specified by invocations in the group.
@@ -183,12 +188,15 @@ def SPIRV_GroupFAddOp : SPIRV_Op<"GroupFAdd", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupFMaxOp : SPIRV_Op<"GroupFMax", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A floating-point maximum group operation specified for all values of X
specified by invocations in the group.
@@ -236,12 +244,15 @@ def SPIRV_GroupFMaxOp : SPIRV_Op<"GroupFMax", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupFMinOp : SPIRV_Op<"GroupFMin", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A floating-point minimum group operation specified for all values of X
specified by invocations in the group.
@@ -289,12 +300,15 @@ def SPIRV_GroupFMinOp : SPIRV_Op<"GroupFMin", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupIAddOp : SPIRV_Op<"GroupIAdd", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
An integer add group operation specified for all values of X specified
by invocations in the group.
@@ -342,12 +356,15 @@ def SPIRV_GroupIAddOp : SPIRV_Op<"GroupIAdd", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupIMulKHROp : SPIRV_KhrVendorOp<"GroupIMul", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
An integer multiplication group operation specified for all values of 'X'
specified by invocations in the group.
@@ -395,12 +412,15 @@ def SPIRV_GroupIMulKHROp : SPIRV_KhrVendorOp<"GroupIMul", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupSMaxOp : SPIRV_Op<"GroupSMax", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A signed integer maximum group operation specified for all values of X
specified by invocations in the group.
@@ -449,12 +469,15 @@ def SPIRV_GroupSMaxOp : SPIRV_Op<"GroupSMax", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupSMinOp : SPIRV_Op<"GroupSMin", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
A signed integer minimum group operation specified for all values of X
specified by invocations in the group.
@@ -503,12 +526,15 @@ def SPIRV_GroupSMinOp : SPIRV_Op<"GroupSMin", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupUMaxOp : SPIRV_Op<"GroupUMax", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
An unsigned integer maximum group operation specified for all values of
X specified by invocations in the group.
@@ -556,12 +582,15 @@ def SPIRV_GroupUMaxOp : SPIRV_Op<"GroupUMax", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_GroupUMinOp : SPIRV_Op<"GroupUMin", [Pure,
- AllTypesMatch<["x", "result"]>]> {
+ AllTypesMatch<["x", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
let summary = [{
An unsigned integer minimum group operation specified for all values of
X specified by invocations in the group.
@@ -610,6 +639,8 @@ def SPIRV_GroupUMinOp : SPIRV_Op<"GroupUMin", [Pure,
let assemblyFormat = [{
$execution_scope $group_operation operands attr-dict `:` type($x)
}];
+
+ let hasVerifier = 0;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 7ede319f85a5b..1a0ab0ff98a8a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -14,17 +14,12 @@
#ifndef MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
#define MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
-class SPIRV_AttrIs<string operand, string type, string value> : PredOpTrait<
- operand # " must be " # type # " of value " # value,
- CPred<"::llvm::cast<::mlir::spirv::" # type # "Attr>(getProperties()." # operand # ").getValue() == ::mlir::spirv::" # type # "::" # value>
- >;
-
-class SPIRV_ExecutionScopeAttrIs<string operand, string value> : SPIRV_AttrIs<operand, "Scope", value>;
-
// -----
class SPIRV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
- list<Trait> traits = []> : SPIRV_Op<mnemonic, traits> {
+ list<Trait> traits = []> : SPIRV_Op<mnemonic, !listconcat([
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>
+ ], traits)> {
let arguments = (ins
SPIRV_ScopeAttr:$execution_scope,
@@ -46,7 +41,9 @@ class SPIRV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
// -----
-def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot", []> {
+def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot",[
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is a bitfield value combining the Predicate value from all
invocations in the group that execute the same dynamic instance of this
@@ -94,11 +91,15 @@ def SPIRV_GroupNonUniformBallotOp : SPIRV_Op<"GroupNonUniformBallot", []> {
let assemblyFormat = [{
$execution_scope $predicate attr-dict `:` type($result)
}];
+
+ let hasVerifier = 0;
}
// -----
-def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLSB", []> {
+def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLSB", [
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Find the least significant bit set to 1 in Value, considering only the
bits in Value required to represent all bits of the group's invocations.
@@ -150,11 +151,15 @@ def SPIRV_GroupNonUniformBallotFindLSBOp : SPIRV_Op<"GroupNonUniformBallotFindLS
let assemblyFormat = [{
$execution_scope $value attr-dict `:` type($value) `,` type($result)
}];
+
+ let hasVerifier = 0;
}
// -----
-def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMSB", []> {
+def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMSB", [
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Find the most significant bit set to 1 in Value, considering only the
bits in Value required to represent all bits of the group's invocations.
@@ -206,12 +211,16 @@ def SPIRV_GroupNonUniformBallotFindMSBOp : SPIRV_Op<"GroupNonUniformBallotFindMS
let assemblyFormat = [{
$execution_scope $value attr-dict `:` type($value) `,` type($result)
}];
+
+ let hasVerifier = 0;
}
// -----
-def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is the Value of the invocation identified by the id Id to all
active invocations in the group.
@@ -269,8 +278,10 @@ def SPIRV_GroupNonUniformBroadcastOp : SPIRV_Op<"GroupNonUniformBroadcast",
// -----
-def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst",
- [Pure, SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFirst", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
let summary = [{
Broadcast the value from the active invocation with the lowest id in
the subgroup.
@@ -323,7 +334,9 @@ def SPIRV_GroupNonUniformBroadcastFirstOp : SPIRV_Op<"GroupNonUniformBroadcastFi
// -----
-def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
+def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", [
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is true only in the active invocation with the lowest id in the
group, otherwise result is false.
@@ -357,6 +370,8 @@ def SPIRV_GroupNonUniformElectOp : SPIRV_Op<"GroupNonUniformElect", []> {
);
let assemblyFormat = "$execution_scope attr-dict `:` type($result)";
+
+ let hasVerifier = 0;
}
// -----
@@ -739,8 +754,10 @@ def SPIRV_GroupNonUniformSMinOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
// -----
-def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is the Value of the invocation identified by the id Id.
}];
@@ -791,8 +808,10 @@ def SPIRV_GroupNonUniformShuffleOp : SPIRV_Op<"GroupNonUniformShuffle",
// -----
-def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is the Value of the invocation identified by the current
invocation’s id within the group + Delta.
@@ -846,8 +865,10 @@ def SPIRV_GroupNonUniformShuffleDownOp : SPIRV_Op<"GroupNonUniformShuffleDown",
// -----
-def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is the Value of the invocation identified by the current
invocation’s id within the group - Delta.
@@ -900,8 +921,10 @@ def SPIRV_GroupNonUniformShuffleUpOp : SPIRV_Op<"GroupNonUniformShuffleUp",
// -----
-def SPIRV_GroupNonUniformShuffleXorOp : SPIRV_Op<"GroupNonUniformShuffleXor",
- [Pure, AllTypesMatch<["value", "result"]>]> {
+def SPIRV_GroupNonUniformShuffleXorOp : SPIRV_Op<"GroupNonUniformShuffleXor", [
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Result is the Value of the invocation identified by the current
invocation’s id within the group xor’ed with Mask.
@@ -1351,8 +1374,8 @@ def SPIRV_GroupNonUniformLogicalXorOp :
// -----
def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCount", [
- SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">,
-]> {
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
let summary = [{
Result is the number of bits that are set to 1 in Value, considering
only the bits in Value required to represent all bits of the scope
@@ -1416,7 +1439,9 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
// -----
def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
- Pure, AllTypesMatch<["value", "result"]>]> {
+ Pure, AllTypesMatch<["value", "result"]>,
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Workgroup", "Subgroup"]>]> {
+
let summary = [{
Rotate values across invocations within a subgroup.
}];
@@ -1490,8 +1515,8 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
// -----
def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
- SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
let summary = [{
Evaluates a predicate for all tangled invocations within the Execution
scope, resulting in true if predicate evaluates to true for all tangled
@@ -1546,8 +1571,8 @@ def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
// -----
def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
- SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
let summary = [{
Evaluates a predicate for all tangled invocations within the Execution
scope, resulting in true if predicate evaluates to true for any tangled
@@ -1602,8 +1627,8 @@ def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
// -----
def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
- SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
-]> {
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>]> {
+
let summary = [{
Evaluates a value for all tangled invocations within the Execution
scope. The result is true if Value is equal for all tangled invocations
@@ -1663,8 +1688,8 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
// -----
def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
- SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>
-]> {
+ SPIRV_ExecutionScopeAttrIs<"execution_scope", ["Subgroup"]>, AllTypesMatch<["value", "result"]>]> {
+
let summary = [{
Swap the Value of the invocation within the quad with another invocation
in the quad using Direction.
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index a1bb7f89e9183..fe6f00e9e5bca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -22,15 +22,6 @@ namespace mlir::spirv {
template <typename OpTy>
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
- spirv::Scope scope =
- groupOp
- ->getAttrOfType<spirv::ScopeAttr>(
- OpTy::getExecutionScopeAttrName(groupOp->getName()))
- .getValue();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return groupOp->emitOpError(
- "execution scope must be 'Workgroup' or 'Subgroup'");
-
GroupOperation operation =
groupOp
->getAttrOfType<GroupOperationAttr>(
@@ -61,10 +52,6 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
//===----------------------------------------------------------------------===//
LogicalResult GroupBroadcastOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
if (auto localIdTy = dyn_cast<VectorType>(getLocalid().getType()))
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
return emitOpError("localid is a vector and can be with only "
@@ -74,51 +61,11 @@ LogicalResult GroupBroadcastOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotFindLSBOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotFindLSBOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotFindLSBOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformBallotFindMSBOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformBroadcast
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBroadcastOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
// SPIR-V spec: "Before version 1.5, Id must come from a
// constant instruction.
auto targetEnv = spirv::getDefaultTargetEnv(getContext());
@@ -141,10 +88,6 @@ LogicalResult GroupNonUniformBroadcastOp::verify() {
template <typename OpTy>
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
- spirv::Scope scope = op.getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
if (op.getOperands().back().getType().isSignedInteger())
return op.emitOpError("second operand must be a singless/unsigned integer");
@@ -164,18 +107,6 @@ LogicalResult GroupNonUniformShuffleXorOp::verify() {
return verifyGroupNonUniformShuffleOp(*this);
}
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformElectOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GroupNonUniformElectOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFAddOp
//===----------------------------------------------------------------------===//
@@ -309,10 +240,6 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformRotateKHROp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
if (Value clusterSizeVal = getClusterSize()) {
mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
int32_t clusterSize = 0;
@@ -327,37 +254,4 @@ LogicalResult GroupNonUniformRotateKHROp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// Group op verification
-//===----------------------------------------------------------------------===//
-
-template <typename Op>
-static LogicalResult verifyGroupOp(Op op) {
- spirv::Scope scope = op.getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
-
} // namespace mlir::spirv
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index e69a07ff885f0..1034b9f02ef52 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -41,7 +41,7 @@ func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>
// -----
func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupBroadcast <Device> %value, %localid : f32, vector<3xi32>
return %0: f32
}
@@ -196,3 +196,4 @@ func.func @group_fmul(%value: f32) -> f32 {
%0 = spirv.KHR.GroupFMul <Workgroup> <Reduce> %value : f32
return %0: f32
}
+
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index fb18d69f58241..abc6964026646 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -13,7 +13,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// -----
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformBallot <Device> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
@@ -41,7 +41,7 @@ func.func @group_non_uniform_ballot_find_lsb(%value : vector<4xi32>) -> i32 {
// -----
func.func @group_non_uniform_ballot_find_lsb(%value : vector<4xi32>) -> i32 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformBallotFindLSB <Device> %value : vector<4xi32>, i32
return %0: i32
}
@@ -69,7 +69,7 @@ func.func @group_non_uniform_ballot_find_msb(%value : vector<4xi32>) -> i32 {
// -----
func.func @group_non_uniform_ballot_find_msb(%value : vector<4xi32>) -> i32 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformBallotFindMSB <Device> %value : vector<4xi32>, i32
return %0: i32
}
@@ -108,7 +108,7 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4
func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
%one = spirv.Constant 1 : i32
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformBroadcast <Device> %value, %one : f32, i32
return %0: f32
}
@@ -146,7 +146,7 @@ func.func @group_non_uniform_broadcast_first_vector(%value: vector<4xf32>) -> ve
// -----
func.func @group_non_uniform_broadcast_first_negative_scope(%value: f32) -> f32 {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformBroadcastFirst <Device> %value : f32
return %0 : f32
}
@@ -184,7 +184,7 @@ func.func @group_non_uniform_elect() -> i1 {
// -----
func.func @group_non_uniform_elect() -> i1 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformElect <CrossDevice> : i1
return %0: i1
}
@@ -295,7 +295,7 @@ func.func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vecto
// -----
func.func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformIAdd <Device> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -395,7 +395,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
// -----
func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformShuffle <Device> %val, %id : vector<2xf32>, i32
return %0: vector<2xf32>
}
@@ -431,7 +431,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
// -----
func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformShuffleXor <Device> %val, %id : vector<2xf32>, i32
return %0: vector<2xf32>
}
@@ -467,7 +467,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
// -----
func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformShuffleUp <Device> %val, %id : vector<2xf32>, i32
return %0: vector<2xf32>
}
@@ -503,7 +503,7 @@ func.func @group_non_uniform_shuffle2(%val: vector<2xf32>, %id: i32) -> vector<2
// -----
func.func @group_non_uniform_shuffle(%val: vector<2xf32>, %id: i32) -> vector<2xf32> {
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformShuffleDown <Device> %val, %id : vector<2xf32>, i32
return %0: vector<2xf32>
}
@@ -695,7 +695,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
%four = spirv.Constant 4 : i32
- // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ // expected-error @+1 {{execution_scope must be 'Workgroup' or 'Subgroup'}}
%0 = spirv.GroupNonUniformRotateKHR <Device> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
return %0: f32
}
@@ -751,7 +751,7 @@ func.func @group_non_uniform_all(%predicate: i1) -> i1 {
// -----
func.func @group_non_uniform_all(%predicate: i1) -> i1 {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformAll <Device> %predicate : i1
return %0: i1
}
@@ -772,7 +772,7 @@ func.func @group_non_uniform_any(%predicate: i1) -> i1 {
// -----
func.func @group_non_uniform_any(%predicate: i1) -> i1 {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformAny <Device> %predicate : i1
return %0: i1
}
@@ -803,7 +803,7 @@ func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 {
// -----
func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
return %0: i1
}
@@ -837,7 +837,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
// -----
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : vector<4xf32>
return %0: vector<4xf32>
}
@@ -874,7 +874,7 @@ func.func @group_non_uniform_ballot_bit_count(%value: vector<4xi32>) -> i32 {
// -----
func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>) -> i32 {
- // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+ // expected-error @+1 {{execution_scope must be 'Subgroup'}}
%0 = spirv.GroupNonUniformBallotBitCount <Workgroup> <Reduce> %value : vector<4xi32> -> i32
return %0: i32
}
More information about the Mlir-commits
mailing list