[Mlir-commits] [mlir] [mlir][gpu][spirv] Add patterns for gpu.shuffle up/down (PR #139105)
Hsiangkai Wang
llvmlistbot at llvm.org
Thu May 8 09:12:51 PDT 2025
https://github.com/Hsiangkai created https://github.com/llvm/llvm-project/pull/139105
Convert
gpu.shuffle down %val, %offset, %width
to
spirv.GroupNonUniformRotateKHR <Subgroup> %val, %offset, cluster_size(%width)
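Convert
gpu.shuffle up %val, %offset, %width
to
%down_offset = arith.subi %width, %offset
spirv.GroupNonUniformRotateKHR <Subgroup> %val, %down_offset, cluster_size(%width)
Within a cluster of %width invocations, rotating down by (%width - %offset) is equivalent to rotating up by %offset.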
In addition, update the spirv.GroupNonUniformRotateKHR assembly format (drop the comma after the execution scope) to be consistent with the other spirv.GroupNonUniform* operations.
>From dbc412f06e2948e62482bbb0933e754aeee51f76 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 8 May 2025 16:36:23 +0100
Subject: [PATCH] [mlir][gpu][spirv] Add patterns for gpu.shuffle up/down
Convert
gpu.shuffle down %val, %offset, %width
to
spirv.GroupNonUniformRotateKHR <Subgroup> %val, %offset, cluster_size(%width)
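Convert
gpu.shuffle up %val, %offset, %width
to
%down_offset = arith.subi %width, %offset
spirv.GroupNonUniformRotateKHR <Subgroup> %val, %down_offset, cluster_size(%width)
Within a cluster of %width invocations, rotating down by
(%width - %offset) is equivalent to rotating up by %offset.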
In addition, update the spirv.GroupNonUniformRotateKHR assembly format
(drop the comma after the execution scope) to be consistent with the
other spirv.GroupNonUniform* operations.
---
.../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 8 +--
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 15 ++++-
mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 57 +++++++++++++++++++
.../Dialect/SPIRV/IR/non-uniform-ops.mlir | 18 +++---
4 files changed, 83 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 2dd3dbd28d436..3fdaff2470cba 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1404,8 +1404,8 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
```mlir
%four = spirv.Constant 4 : i32
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
- %1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
- clustersize(%four) : f32, i32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %value, %delta : f32, i32 -> f32
+ %1 = spirv.GroupNonUniformRotateKHR <Workgroup> %value, %delta,
+ cluster_size(%four) : f32, i32, i32 -> f32
```
}];
@@ -1429,7 +1429,7 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
);
let assemblyFormat = [{
- $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
+ $execution_scope $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
}];
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..fabbebef41b21 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -450,8 +450,19 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
break;
- default:
- return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
+ case gpu::ShuffleMode::DOWN:
+ result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+ loc, scope, adaptor.getValue(), adaptor.getOffset(),
+ adaptor.getWidth());
+ break;
+ case gpu::ShuffleMode::UP: {
+ Value offsetForShuffleDown = rewriter.create<arith::SubIOp>(
+ loc, adaptor.getWidth(), adaptor.getOffset());
+ result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+ loc, scope, adaptor.getValue(), offsetForShuffleDown,
+ adaptor.getWidth());
+ break;
+ }
}
rewriter.replaceOp(shuffleOp, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
index d3d8ec0dab40f..5d7d3c81577e3 100644
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -72,3 +72,60 @@ gpu.module @kernels {
}
}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+ #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @shuffle_down()
+ gpu.func @shuffle_down() kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 16 : i32
+ %val = arith.constant 42.0 : f32
+
+ // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+ // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+ // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+ // CHECK: %{{.+}} = spirv.Constant true
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+ %result, %valid = gpu.shuffle down %val, %offset, %width : f32
+ gpu.return
+ }
+}
+
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+ #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @shuffle_up()
+ gpu.func @shuffle_up() kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 16 : i32
+ %val = arith.constant 42.0 : f32
+
+ // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+ // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+ // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+ // CHECK: %{{.+}} = spirv.Constant true
+ // CHECK: %[[DOWN_OFFSET:.+]] = spirv.Constant 12 : i32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+ %result, %valid = gpu.shuffle up %val, %offset, %width : f32
+ gpu.return
+ }
+}
+
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index bf383d3837b6e..6990f2b3751f5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -613,8 +613,8 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
// CHECK-LABEL: @group_non_uniform_rotate_khr
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
- // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %{{.+}} : f32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta : f32, i32 -> f32
return %0: f32
}
@@ -622,9 +622,9 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
// CHECK-LABEL: @group_non_uniform_rotate_khr
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
- // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup> %{{.+}} : f32, i32, i32 -> f32
%four = spirv.Constant 4 : i32
- %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Workgroup> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
return %0: f32
}
@@ -633,7 +633,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
%four = spirv.Constant 4 : i32
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
- %0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Device> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
return %0: f32
}
@@ -642,7 +642,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
%four = spirv.Constant 4 : i32
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
return %0: f32
}
@@ -651,7 +651,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
%four = spirv.Constant 4 : si32
// expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
return %0: f32
}
@@ -659,7 +659,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 {
// expected-error @+1 {{cluster size operand must come from a constant op}}
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
return %0: f32
}
@@ -668,6 +668,6 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
%five = spirv.Constant 5 : i32
// expected-error @+1 {{cluster size operand must be a power of two}}
- %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
return %0: f32
}
More information about the Mlir-commits
mailing list