[Mlir-commits] [mlir] 9377c1d - [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction values using an attribute (#178684)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 29 09:35:08 PST 2026
Author: Igor Wodiany
Date: 2026-01-29T17:35:04Z
New Revision: 9377c1d68b64d9c3dc70f5e04d2e7843c680ea76
URL: https://github.com/llvm/llvm-project/commit/9377c1d68b64d9c3dc70f5e04d2e7843c680ea76
DIFF: https://github.com/llvm/llvm-project/commit/9377c1d68b64d9c3dc70f5e04d2e7843c680ea76.diff
LOG: [mlir][spirv] Enforce `GroupNonUniformQuadSwap` direction values using an attribute (#178684)
The direction can take only one of three values {0, 1, 2}, so we use
a SPIR-V enum attribute to enforce it. This constraint could not be
enforced while the direction was an SSA operand (e.g. produced by a
`spirv.Constant`), because the verifier cannot check non-local
properties such as the value computed by an operand's defining op.
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
mlir/test/Target/SPIRV/non-uniform-ops.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index f8093d3042c50..8b9f51d38374c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4187,6 +4187,15 @@ def SPIRV_INTEL_StoreCacheControlAttr :
SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
]>;
+// Non-uniform quad swap direction attribute
+def SPIRV_QuadSwapDirectionAttr : SPIRV_I32EnumAttr<
+ "QuadSwapDirection", "Swap direction of a GroupNonUniformQuadSwap", "quad_swap_direction",
+ [
+ I32EnumAttrCase<"Horizontal", 0>,
+ I32EnumAttrCase<"Vertical", 1>,
+ I32EnumAttrCase<"Diagonal", 2>
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index d5a339115aaaa..37df66372f51c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1657,8 +1657,8 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
#### Example:
```mlir
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
- %1 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : f32
+ %1 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %value : vector<4xf32>
```
}];
@@ -1672,7 +1672,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
let arguments = (ins
SPIRV_ScopeAttr:$execution_scope,
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
- SPIRV_SignlessOrUnsignedInt:$direction
+ SPIRV_QuadSwapDirectionAttr:$direction
);
let results = (outs
@@ -1682,7 +1682,7 @@ def SPIRV_GroupNonUniformQuadSwapOp : SPIRV_Op<"GroupNonUniformQuadSwap", [
let hasVerifier = 0;
let assemblyFormat = [{
- $execution_scope $value $direction attr-dict `:` type($value) `,` type($direction)
+ $execution_scope $direction $value attr-dict `:` type($value)
}];
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 168823a6e9c2d..5383f7656a1be 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -769,28 +769,29 @@ func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
// CHECK-LABEL: @group_non_uniform_quad_swap
func.func @group_non_uniform_quad_swap(%value: f32) -> f32 {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : f32, i32
- return %0: f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : f32
+ %1 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %0 : f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %{{.+}} : f32
+ %2 = spirv.GroupNonUniformQuadSwap <Subgroup> <Diagonal> %1 : f32
+ return %2: f32
}
// -----
// CHECK-LABEL: @group_non_uniform_quad_swap
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %value %dir : vector<4xf32>, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : vector<4xf32>
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %value : vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
- %dir = spirv.Constant 0 : i32
// expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, i32
+ %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : vector<4xf32>
return %0: vector<4xf32>
}
@@ -798,16 +799,15 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
%dir = spirv.Constant 0.0 : f32
- // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'f32'}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : vector<4xf32>, f32
+ // expected-error @+1 {{expected '<'}}
+ %0 = spirv.GroupNonUniformQuadSwap <Device> %dir %value : vector<4xf32>
return %0: vector<4xf32>
}
// -----
func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
- %dir = spirv.Constant 0 : i32
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
- %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
+ %0 = spirv.GroupNonUniformQuadSwap <Device> <Horizontal> %value : !spirv.array<3 x i32>
return %0: !spirv.array<3 x i32>
}
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index 0c5df594ac1d7..6975836d3ddee 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -149,16 +149,14 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, Linkage, GroupNo
}
spirv.func @group_non_uniform_quad_swap_vec(%val: vector<4xf32>) -> vector<4xf32> "None" {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : vector<4xf32>, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : vector<4xf32>, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %{{.+}} : vector<4xf32>
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Vertical> %val : vector<4xf32>
spirv.ReturnValue %0: vector<4xf32>
}
spirv.func @group_non_uniform_quad_swap_scalar(%val: f32) -> f32 "None" {
- %dir = spirv.Constant 0 : i32
- // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> %{{.+}} %{{.+}} : f32, i32
- %0 = spirv.GroupNonUniformQuadSwap <Subgroup> %val %dir : f32, i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %{{.+}} : f32
+ %0 = spirv.GroupNonUniformQuadSwap <Subgroup> <Horizontal> %val : f32
spirv.ReturnValue %0: f32
}
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 0b1771ffcee71..9cb48934b2c10 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -503,6 +503,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
"SPIRV_MatrixLayoutAttr",
"SPIRV_TosaExtAccTypeAttr",
"SPIRV_TosaExtNaNPropagationModeAttr",
+ "SPIRV_QuadSwapDirectionAttr",
};
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
More information about the Mlir-commits mailing list