[Mlir-commits] [mlir] bc026e3 - [mlir][spirv] Add support for GroupNonUniformQuadSwap (#174747)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 7 07:13:21 PST 2026
Author: Igor Wodiany
Date: 2026-01-07T15:13:17Z
New Revision: bc026e37322249c68069460f2ad33e8992e1ed5c
URL: https://github.com/llvm/llvm-project/commit/bc026e37322249c68069460f2ad33e8992e1ed5c
DIFF: https://github.com/llvm/llvm-project/commit/bc026e37322249c68069460f2ad33e8992e1ed5c.diff
LOG: [mlir][spirv] Add support for GroupNonUniformQuadSwap (#174747)
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
mlir/test/Target/SPIRV/non-uniform-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ecbbf39a534e1..97ee9e15a68ef 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4576,6 +4576,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpGroupNonUniformQuadSwap : I32EnumAttrCase<"OpGroupNonUniformQuadSwap", 366>; // 365 (presumably OpGroupNonUniformQuadBroadcast) is skipped here - verify
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
@@ -4702,7 +4703,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
- SPIRV_OC_OpGroupNonUniformLogicalXor,
+ SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformQuadSwap,
SPIRV_OC_OpTypeTensorARM,
SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 7e2ab64afc6d0..d5a339115aaaa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1608,4 +1608,82 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
// -----
+def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>
+]> {
+  let summary = [{
+    Swap the Value of the invocation within the quad with another invocation
+    in the quad using Direction.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type, integer type,
+    or Boolean type.
+
+    Execution is a Scope, but has no effect on the behavior of this instruction.
+    It must be Subgroup.
+
+    The type of Value must be the same as Result Type.
+
+    Direction is the kind of swap to perform.
+
+    Direction must be a scalar of integer type, whose Signedness operand is 0.
+
+    Direction must come from a constant instruction.
+
+    The value returned in Result is the value provided to Value by another
+    invocation in the same quad scope instance. The invocation providing this
+    value is determined according to Direction.
+
+    A Direction of 0 indicates a horizontal swap:
+
+    - Invocations with quad indices of 0 and 1 swap values
+    - Invocations with quad indices of 2 and 3 swap values
+
+    A Direction of 1 indicates a vertical swap:
+
+    - Invocations with quad indices of 0 and 2 swap values
+    - Invocations with quad indices of 1 and 3 swap values
+
+    A Direction of 2 indicates a diagonal swap:
+
+    - Invocations with quad indices of 0 and 3 swap values
+    - Invocations with quad indices of 1 and 2 swap values
+
+    Direction must be one of the above values.
+
+    If a tangled invocation within the quad reads Value from an invocation that
+    is not one of the tangled invocations within the same quad, the resulting
+    value is undefined.
+
+    An invocation will not execute a dynamic instance of this instruction (X')
+    until all invocations in its quad have executed all dynamic instances that
+    are program-ordered before X'.
+
+    #### Example:
+
+    ```mlir
+    %dir = spirv.Constant 0 : i32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %scalar %dir : f32, i32
+    %1 = spirv.GroupNonUniformQuadSwap <Subgroup> %vector %dir : vector<4xf32>, i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformQuad]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
+    SPIRV_SignlessOrUnsignedInt:$direction
+  );
+
+  let results = (outs
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
+  );
+
+  // No custom C++ verifier: the ODS constraints above only check that
+  // Direction is a signless or unsigned integer. The spec requirements that
+  // Direction comes from a constant instruction and is 0, 1, or 2 are not
+  // enforced by this op definition.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction)
+  }];
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index bdb2abde8d8e6..b22951f90510a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -760,3 +760,54 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
%0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
return %0: i1
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformQuadSwap
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_quad_swap
+func.func @group_non_uniform_quad_swap(%input: f32) -> f32 {
+  // Scalar round-trip using Direction 0 (horizontal swap).
+  %direction = spirv.Constant 0 : i32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
+  %swapped = spirv.GroupNonUniformQuadSwap <Subgroup> %input %direction : f32, i32
+  return %swapped : f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_quad_swap
+func.func @group_non_uniform_quad_swap(%input: vector<4xf32>) -> vector<4xf32> {
+  // Vector round-trip using Direction 0 (horizontal swap).
+  %direction = spirv.Constant 0 : i32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
+  %swapped = spirv.GroupNonUniformQuadSwap <Subgroup> %input %direction : vector<4xf32>, i32
+  return %swapped : vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%input: vector<4xf32>) -> vector<4xf32> {
+  // Any execution scope other than Subgroup must be rejected.
+  %direction = spirv.Constant 0 : i32
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  %swapped = spirv.GroupNonUniformQuadSwap <Device> %input %direction : vector<4xf32>, i32
+  return %swapped : vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
+  %dir = spirv.Constant 0.0 : f32
+  // Valid <Subgroup> scope so the only diagnostic is the Direction type error.
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}}
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, f32
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+  %dir = spirv.Constant 0 : i32
+  // Valid <Subgroup> scope so the only diagnostic is the Value type error.
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}}
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : !spirv.array<3 x i32>, i32
+  return %0: !spirv.array<3 x i32>
+}
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index e7cf0a8905a81..0c5df594ac1d7 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -5,7 +5,7 @@
// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
// RUN: %if spirv-tools %{ spirv-val %t %}
-spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNonUniformBallot, GroupNonUniformArithmetic, GroupNonUniformClustered, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformVote], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNonUniformBallot, GroupNonUniformArithmetic, GroupNonUniformClustered, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformVote, GroupNonUniformQuad], []> {
// CHECK-LABEL: @group_non_uniform_ballot
spirv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" {
// CHECK: %{{.*}} = spirv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
@@ -147,4 +147,18 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
%0 = spirv.GroupNonUniformAllEqual <Subgroup> %val : vector<4xi32>, i1
spirv.ReturnValue %0: i1
}
+
+  // CHECK-LABEL: @group_non_uniform_quad_swap_vec
+  spirv.func @group_non_uniform_quad_swap_vec(%val: vector<4xf32>) -> vector<4xf32> "None" {
+    %dir = spirv.Constant 0 : i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : vector<4xf32>, i32
+    spirv.ReturnValue %0: vector<4xf32>
+  }
+
+  // CHECK-LABEL: @group_non_uniform_quad_swap_scalar
+  spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" {
+    %dir = spirv.Constant 0 : i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : f32, i32
+    spirv.ReturnValue %0: f32
+  }
}
More information about the Mlir-commits
mailing list