[Mlir-commits] [mlir] [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction values using an attribute (PR #178684)
Igor Wodiany
llvmlistbot at llvm.org
Thu Jan 29 07:49:50 PST 2026
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/178684
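The direction can only take one of three values {0, 1, 2} (Horizontal, Vertical, Diagonal), so we model it as a SPIR-V enum attribute to enforce this. The property could not be enforced while the direction was an SSA operand (even one produced by a `spirv.Constant`), because the verifier cannot check non-local properties such as the value computed by a defining op.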
>From 9f8cfa6df8280a20abc3c3d3628a7a79a8364641 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Wed, 28 Jan 2026 16:44:42 +0000
Subject: [PATCH] [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction
values using an attribute
The direction can only take one of three values {0, 1, 2}
(Horizontal, Vertical, Diagonal), so we model it as a SPIR-V enum
attribute to enforce this. The property could not be enforced while
the direction was an SSA operand (even one produced by a constant op),
because the verifier cannot check non-local properties.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 9 +++++++
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 4 +--
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 26 +++++++++----------
mlir/test/Target/SPIRV/non-uniform-ops.mlir | 10 +++----
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 1 +
5 files changed, 29 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f8093d3042c50..8b9f51d38374c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4187,6 +4187,15 @@ def SPIRV_INTEL_StoreCacheControlAttr :
SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
]>;
+// Non-uniform quad swap direction attribute
+def SPIRV_QuadSwapDirectionAttr : SPIRV_I32EnumAttr<
+ "QuadSwapDirection", "Swap direction of a GroupNonUniformQuadSwap", "quad_swap_direction",
+ [
+ I32EnumAttrCase<"Horizontal", 0>,
+ I32EnumAttrCase<"Vertical", 1>,
+ I32EnumAttrCase<"Diagonal", 2>
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index d5a339115aaaa..d6e37b7e4d8a0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1672,7 +1672,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
let arguments = (ins
SPIRV_ScopeAttr:$execution_scope,
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
- SPIRV_SignlessOrUnsignedInt:$direction
+ SPIRV_QuadSwapDirectionAttr:$direction
);
let results = (outs
@@ -1682,7 +1682,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
let hasVerifier = 0;
let assemblyFormat = [{
- $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction)
+ $execution_scope $direction $value attr-dict `:` type($value)
}];
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 168823a6e9c2d..5383f7656a1be 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -769,28 +769,29 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
// CHECK-LABEL: @group_non_uniform_quad_swap
func.func @group_non_uniform_quad_swap(%value: f32) -> f32 {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
- return %0: f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : f32
+ %1 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %0 : f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %{{.+}} : f32
+ %2 = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %1 : f32
+ return %2: f32
}
// -----
// CHECK-LABEL: @group_non_uniform_quad_swap
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : vector<4xf32>
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
- %dir = spirv.Constant 0 : i32
// expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, i32
+ %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : vector<4xf32>
return %0: vector<4xf32>
}
@@ -798,16 +799,15 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
%dir = spirv.Constant 0.0 : f32
- // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, f32
+ // expected-error @+1 {{expected '<'}}
+ %0 = spirv.GroupNonUniformQuadSwap <Device> %dir %value : vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
- %dir = spirv.Constant 0 : i32
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
+ %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : !spirv.array<3 x i32>
return %0: !spirv.array<3 x i32>
}
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 0c5df594ac1d7..6975836d3ddee 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -149,16 +149,14 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
}
spirv.func @group_non_uniform_quad_swap_vec(%val: vector<4xf32>) -> vector<4xf32> "None" {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : vector<4xf32>, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : vector<4xf32>
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %val : vector<4xf32>
spirv.ReturnValue %0: vector<4xf32>
}
spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : f32, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %val : f32
spirv.ReturnValue %0: f32
}
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 0b1771ffcee71..9cb48934b2c10 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -503,6 +503,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
"SPIRV_MatrixLayoutAttr",
"SPIRV_TosaExtAccTypeAttr",
"SPIRV_TosaExtNaNPropagationModeAttr",
+ "SPIRV_QuadSwapDirectionAttr",
};
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
More information about the Mlir-commits
mailing list