[Mlir-commits] [mlir] [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction values using an attribute (PR #178684)

Igor Wodiany llvmlistbot at llvm.org
Thu Jan 29 07:49:50 PST 2026


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/178684

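The direction can only take one of three values {0, 1, 2}, so we model it as a SPIR-V enum attribute to enforce this. The restriction could not be enforced while the direction was an SSA operand (e.g. the result of a `spirv.Constant`), because the verifier cannot check non-local properties such as the value produced by the operand's defining op.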

>From 9f8cfa6df8280a20abc3c3d3628a7a79a8364641 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Wed, 28 Jan 2026 16:44:42 +0000
Subject: [PATCH] [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction
 values using an attribute

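The direction can only take one of three values {0, 1, 2}, so we
model it as a SPIR-V enum attribute to enforce this. The restriction
could not be enforced while the direction was an SSA operand (e.g. the
result of a `spirv.Constant`), because the verifier cannot check
non-local properties such as the value produced by the operand's
defining op.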
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  9 +++++++
 .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td    |  4 +--
 .../Dialect/SPIRV/IR/non-uniform-ops.mlir     | 26 +++++++++----------
 mlir/test/Target/SPIRV/non-uniform-ops.mlir   | 10 +++----
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  1 +
 5 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f8093d3042c50..8b9f51d38374c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4187,6 +4187,15 @@ def SPIRV_INTEL_StoreCacheControlAttr :
       SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
     ]>;
 
+// Non-uniform quad swap direction attribute
+def SPIRV_QuadSwapDirectionAttr : SPIRV_I32EnumAttr<
+  "QuadSwapDirection", "Swap direction of a GroupNonUniformQuadSwap", "quad_swap_direction",
+  [
+      I32EnumAttrCase<"Horizontal", 0>,
+      I32EnumAttrCase<"Vertical", 1>,
+      I32EnumAttrCase<"Diagonal", 2>
+  ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index d5a339115aaaa..d6e37b7e4d8a0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1672,7 +1672,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
   let arguments = (ins
     SPIRV_ScopeAttr:$execution_scope,
     AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
-    SPIRV_SignlessOrUnsignedInt:$direction
+    SPIRV_QuadSwapDirectionAttr:$direction
   );
 
   let results = (outs
@@ -1682,7 +1682,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
   let hasVerifier = 0;
 
   let assemblyFormat = [{
-    $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction)
+    $execution_scope $direction $value attr-dict `:` type($value)
   }];
 }
 
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 168823a6e9c2d..5383f7656a1be 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -769,28 +769,29 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
 
 // CHECK-LABEL: @group_non_uniform_quad_swap
 func.func @group_non_uniform_quad_swap(%value: f32) -> f32 {
-  %dir = spirv.Constant 0 : i32
-  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
-  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
-  return %0: f32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : f32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : f32
+  %1 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %0 : f32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %{{.+}} : f32
+  %2 = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %1 : f32
+  return %2: f32
 }
 
 // -----
 
 // CHECK-LABEL: @group_non_uniform_quad_swap
 func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
-  %dir = spirv.Constant 0 : i32
-  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
-  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+  // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : vector<4xf32>
+  %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : vector<4xf32>
   return %0: vector<4xf32>
 }
 
 // -----
 
 func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
-  %dir = spirv.Constant 0 : i32
   // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
-  %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, i32
+  %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : vector<4xf32>
   return %0: vector<4xf32>
 }
 
@@ -798,16 +799,15 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
 
 func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
   %dir = spirv.Constant 0.0 : f32
-  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}}
-  %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, f32
+  // expected-error @+1 {{expected '<'}}
+  %0 = spirv.GroupNonUniformQuadSwap <Device> %dir %value : vector<4xf32>
   return %0: vector<4xf32>
 }
 
 // -----
 
 func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
-  %dir = spirv.Constant 0 : i32
   // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
-  %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
+  %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : !spirv.array<3 x i32>
   return %0: !spirv.array<3 x i32>
 }
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 0c5df594ac1d7..6975836d3ddee 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -149,16 +149,14 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
   }
 
   spirv.func @group_non_uniform_quad_swap_vec(%val: vector<4xf32>) -> vector<4xf32> "None" {
-    %dir = spirv.Constant 0 : i32
-    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
-    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : vector<4xf32>, i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : vector<4xf32>
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %val : vector<4xf32>
     spirv.ReturnValue %0: vector<4xf32>
   }
 
   spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" {
-    %dir = spirv.Constant 0 : i32
-    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
-    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : f32, i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+    %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %val : f32
     spirv.ReturnValue %0: f32
   }
 }
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 0b1771ffcee71..9cb48934b2c10 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -503,6 +503,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
     "SPIRV_MatrixLayoutAttr",
     "SPIRV_TosaExtAccTypeAttr",
     "SPIRV_TosaExtNaNPropagationModeAttr",
+    "SPIRV_QuadSwapDirectionAttr",
 };
 
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The



More information about the Mlir-commits mailing list