[Mlir-commits] [mlir] bc026e3 - [mlir][spirv] Add support for GroupNonUniformQuadSwap (#174747)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 07:13:21 PST 2026


Author: Igor Wodiany
Date: 2026-01-07T15:13:17Z
New Revision: bc026e37322249c68069460f2ad33e8992e1ed5c

URL: https://github.com/llvm/llvm-project/commit/bc026e37322249c68069460f2ad33e8992e1ed5c
DIFF: https://github.com/llvm/llvm-project/commit/bc026e37322249c68069460f2ad33e8992e1ed5c.diff

LOG: [mlir][spirv] Add support for GroupNonUniformQuadSwap (#174747)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
    mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
    mlir/test/Target/SPIRV/non-uniform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ecbbf39a534e1..97ee9e15a68ef 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4576,6 +4576,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpGroupNonUniformQuadSwap        : I32EnumAttrCase<"OpGroupNonUniformQuadSwap", 366>; // NOTE(review): 365 is skipped here (presumably OpGroupNonUniformQuadBroadcast) — confirm it is intentionally deferred.
 def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
 def SPIRV_OC_OpGraphConstantARM               : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
 def SPIRV_OC_OpGraphEntryPointARM             : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
@@ -4702,7 +4703,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor,
+      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformQuadSwap,
       SPIRV_OC_OpTypeTensorARM,
       SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
       SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 7e2ab64afc6d0..d5a339115aaaa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1608,4 +1608,82 @@ def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
 
 // -----
 
+def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
+  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">, AllTypesMatch<["value", "result"]>
+]> {
+  let summary = [{
+    Swap the Value of the invocation within the quad with another invocation
+    in the quad using Direction.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type, integer type,
+    or Boolean type.
+
+    Execution is a Scope, but has no effect on the behavior of this instruction.
+    It must be Subgroup.
+
+    The type of Value must be the same as Result Type.
+
+    Direction is the kind of swap to perform.
+
+    Direction must be a scalar of integer type, whose Signedness operand is 0.
+
+    Direction must come from a constant instruction.
+
+    The value returned in Result is the value provided to Value by another invocation
+    in the same quad scope instance. The invocation providing this value is
+    determined according to Direction.
+
+    A Direction of 0 indicates a horizontal swap;
+    - Invocations with quad indices of 0 and 1 swap values
+    - Invocations with quad indices of 2 and 3 swap values
+    A Direction of 1 indicates a vertical swap;
+    - Invocations with quad indices of 0 and 2 swap values
+    - Invocations with quad indices of 1 and 3 swap values
+    A Direction of 2 indicates a diagonal swap;
+    - Invocations with quad indices of 0 and 3 swap values
+    - Invocations with quad indices of 1 and 2 swap values
+
+    Direction must be one of the above values.
+
+    If a tangled invocation within the quad reads Value from an invocation not part
+    of the tangled invocation within the same quad, the resulting value is undefined.
+
+    An invocation will not execute a dynamic instance of this instruction (X') until
+    all invocations in its quad have executed all dynamic instances that are program-ordered
+    before X'.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
+    %1 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformQuad]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
+    SPIRV_SignlessOrUnsignedInt:$direction
+  );
+
+  let results = (outs
+    AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
+  );
+
+  let hasVerifier = 0; // NOTE(review): no custom verifier, so the description's Direction rules (constant operand, value 0/1/2) are not checked here; only the integer type of $direction is.
+
+  let assemblyFormat = [{
+    $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction)
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS

diff  --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index bdb2abde8d8e6..b22951f90510a 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -760,3 +760,56 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
   %0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
   return %0: i1
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformQuadSwap
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_quad_swap
+func.func @group_non_uniform_quad_swap(%value: f32) -> f32 {
+  %dir = spirv.Constant 0 : i32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_quad_swap
+func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
+  %dir = spirv.Constant 0 : i32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
+  %dir = spirv.Constant 0 : i32
+  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
+  %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, i32
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
+  // Valid <Subgroup> scope, so only the non-integer direction can be rejected.
+  %dir = spirv.Constant 0.0 : f32
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}}
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, f32
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
+  // Valid <Subgroup> scope, so only the array-typed value can be rejected.
+  %dir = spirv.Constant 0 : i32
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}}
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : !spirv.array<3 x i32>, i32
+  return %0: !spirv.array<3 x i32>
+}

diff  --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index e7cf0a8905a81..0c5df594ac1d7 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -5,7 +5,7 @@
 // RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
 // RUN: %if spirv-tools %{ spirv-val %t %}
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNonUniformBallot, GroupNonUniformArithmetic, GroupNonUniformClustered, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformVote], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNonUniformBallot, GroupNonUniformArithmetic, GroupNonUniformClustered, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformVote, GroupNonUniformQuad], []> {
   // CHECK-LABEL: @group_non_uniform_ballot
   spirv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" {
     // CHECK: %{{.*}} = spirv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
@@ -147,4 +147,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
     %0 = spirv.GroupNonUniformAllEqual <Subgroup> %val : vector<4xi32>, i1
     spirv.ReturnValue %0: i1
   }
+
+  // CHECK-LABEL: @group_non_uniform_quad_swap_vec
+  spirv.func @group_non_uniform_quad_swap_vec(%val: vector<4xf32>) -> vector<4xf32> "None" {
+    %dir = spirv.Constant 0 : i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : vector<4xf32>, i32
+    spirv.ReturnValue %0: vector<4xf32>
+  }
+
+  // CHECK-LABEL: @group_non_uniform_quad_swap_scalar
+  spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" {
+    %dir = spirv.Constant 0 : i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : f32, i32
+    spirv.ReturnValue %0: f32
+  }
 }


        


More information about the Mlir-commits mailing list