[Mlir-commits] [mlir] [mlir][gpu-to-spirv] Add support for `gpu.func` workgroup attributions (PR #189744)
Amir Mohammad Tavakkoli
llvmlistbot at llvm.org
Fri Apr 10 09:56:14 PDT 2026
https://github.com/tavakkoliamirmohammad updated https://github.com/llvm/llvm-project/pull/189744
From 35385e0fb2ecc4641f554f79c1d4b91983d4a1d6 Mon Sep 17 00:00:00 2001
From: Amir Mohammad Tavakkoli <tavakkoli.amirmohammad at gmail.com>
Date: Tue, 31 Mar 2026 14:34:36 -0600
Subject: [PATCH] [mlir][gpu-to-spirv] Add support for gpu.func workgroup attributions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The GPU-to-SPIR-V conversion pass currently ignores workgroup attributions
on gpu.func, causing any kernel using shared/workgroup memory to fail
conversion. This patch adds support by:
1. Creating spirv.GlobalVariable ops at module scope with Workgroup
storage class for each workgroup attribution.
2. Creating spirv.mlir.addressof ops at function entry to reference
the global variables, and remapping the workgroup block arguments
to these addressof results via the signature conversion.
3. Fixing getDefaultABIAttrs() and the fallback ABI path in
GPUFuncOpConversion to only iterate over regular function arguments
(not workgroup/private attributions), preventing incorrect
descriptor set/binding assignment to device-local memory.
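4. Skipping the workgroup_attributions attribute (the count of workgroup
attributions) when copying attributes from gpu.func to spirv.func, since
the workgroup arguments are replaced by spirv.mlir.addressof results.
5. Rejecting kernels that have private attributions with an explicit
error, since their SPIR-V lowering is not supported.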
This enables the full SPIR-V cross-compilation path for shared memory
GPU kernels, targeting Vulkan, WebGPU (via SPIR-V → WGSL), Metal
(via SPIR-V → MSL), and OpenCL 2.1+.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
---
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 70 ++++++++++-
.../Conversion/GPUToSPIRV/gpu-to-spirv.mlir | 114 ++++++++++++++++++
.../GPUToSPIRV/workgroup-memory.mlir | 76 ++++++++++++
3 files changed, 255 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Conversion/GPUToSPIRV/workgroup-memory.mlir
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 0a8d80e232456..b7e8291075e83 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -265,10 +265,26 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
"or none of them");
return nullptr;
}
+
+ // Collect workgroup attribution types before inlining moves the body.
+ SmallVector<Type> workgroupTypes;
+ for (BlockArgument wgAttr : funcOp.getWorkgroupAttributions()) {
+ Type convertedType = typeConverter.convertType(wgAttr.getType());
+ if (!convertedType) {
+ funcOp.emitError("unable to convert workgroup attribution type");
+ return nullptr;
+ }
+ workgroupTypes.push_back(convertedType);
+ }
+
// Update the signature to valid SPIR-V types and add the ABI
// attributes. These will be "materialized" by using the
- // LowerABIAttributesPass.
- TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ // LowerABIAttributesPass. The signature conversion covers both regular
+ // function arguments and workgroup attributions.
+ unsigned numFuncArgs = fnType.getNumInputs();
+ unsigned numWorkgroupAttrs = funcOp.getNumWorkgroupAttributions();
+ TypeConverter::SignatureConversion signatureConverter(numFuncArgs +
+ numWorkgroupAttrs);
{
for (const auto &argType :
enumerate(funcOp.getFunctionType().getInputs())) {
@@ -283,13 +299,47 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
- namedAttr.getName() == SymbolTable::getSymbolAttrName())
+ namedAttr.getName() == SymbolTable::getSymbolAttrName() ||
+ namedAttr.getName() ==
+ gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
continue;
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
+
+ // Handle workgroup attributions by creating module-scope global variables
+ // and remapping workgroup block arguments to spirv.mlir.addressof results.
+ if (numWorkgroupAttrs > 0) {
+ Location loc = funcOp.getLoc();
+
+ // Create global variables at module scope (before the function).
+ SmallVector<spirv::GlobalVariableOp> workgroupGlobals;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(newFuncOp);
+ for (auto [idx, type] : llvm::enumerate(workgroupTypes)) {
+ std::string name = std::string("__workgroup_mem__") +
+ funcOp.getName().str() + "_" + std::to_string(idx);
+ auto globalOp = spirv::GlobalVariableOp::create(
+ rewriter, loc, type, name, /*initializer=*/nullptr);
+ workgroupGlobals.push_back(globalOp);
+ }
+ }
+
+ // Create addressof ops at the beginning of the function entry block
+ // and remap workgroup block arguments to the addressof results.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+ for (auto [idx, globalOp] : llvm::enumerate(workgroupGlobals)) {
+ auto addrOf = spirv::AddressOfOp::create(rewriter, loc, globalOp);
+ signatureConverter.remapInput(numFuncArgs + idx, addrOf.getResult());
+ }
+ }
+ }
+
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
&signatureConverter)))
return nullptr;
@@ -314,7 +364,10 @@ getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
return success();
- for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+ // Only generate ABI attributes for regular function arguments; workgroup
+ // and private attributions never get a descriptor set/binding.
+ for (auto argIndex :
+ llvm::seq<unsigned>(0, funcOp.getFunctionType().getNumInputs())) {
if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
argIndex, spirv::getInterfaceVarABIAttrName()))
return failure();
@@ -335,12 +388,19 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
if (!gpu::GPUDialect::isKernel(funcOp))
return failure();
+ if (funcOp.getNumPrivateAttributions() > 0)
+ return funcOp.emitError(
+ "SPIR-V lowering of private attributions is not supported");
+
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
if (failed(
getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
argABI.clear();
- for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+ // Only check ABI attributes for regular function arguments, not for
+ // workgroup attributions which don't have descriptor set/binding.
+ for (auto argIndex :
+ llvm::seq<unsigned>(0, funcOp.getFunctionType().getNumInputs())) {
// If the ABI is already specified, use it.
auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
argIndex, spirv::getInterfaceVarABIAttrName());
diff --git a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
index 7bf6f8419be0d..556c713111546 100644
--- a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
@@ -107,6 +107,26 @@ module attributes {gpu.container_module} {
// -----
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ gpu.module @kernels {
+ // expected-error @below {{failed to legalize operation 'gpu.func'}}
+ // expected-error @below {{SPIR-V lowering of private attributions is not supported}}
+ gpu.func @private_attribution_unsupported(
+ %arg0: memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ workgroup(%wg: memref<256xf32, #spirv.storage_class<Workgroup>>)
+ private(%priv: memref<4xf32, #spirv.storage_class<Function>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>} {
+ gpu.return
+ }
+ }
+}
+
+// -----
+
module attributes {gpu.container_module} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @barrier
@@ -128,3 +148,97 @@ module attributes {gpu.container_module} {
return
}
}
+
+// -----
+
+// Test gpu.func with a single workgroup attribution.
+module attributes {gpu.container_module} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__kernel_wg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32>)>, Workgroup>
+ // CHECK-LABEL: spirv.func @kernel_wg
+ // CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>}
+ // CHECK-NOT: Workgroup
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>
+ // CHECK: spirv.mlir.addressof @__workgroup_mem__kernel_wg_0
+ // CHECK: spirv.Return
+ gpu.func @kernel_wg(%arg0 : f32)
+ workgroup(%wg : memref<256xf32, #spirv.storage_class<Workgroup>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>} {
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %0 = "op"() : () -> (f32)
+ %cst = arith.constant 1 : index
+ gpu.launch_func @kernels::@kernel_wg
+ blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+ args(%0 : f32)
+ return
+ }
+}
+
+// -----
+
+// Test gpu.func with multiple workgroup attributions.
+module attributes {gpu.container_module} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__kernel_multi_wg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32>)>, Workgroup>
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__kernel_multi_wg_1 : !spirv.ptr<!spirv.struct<(!spirv.array<64 x i32>)>, Workgroup>
+ // CHECK-LABEL: spirv.func @kernel_multi_wg
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [128, 1, 1]>
+ // CHECK: spirv.mlir.addressof @__workgroup_mem__kernel_multi_wg_0
+ // CHECK: spirv.mlir.addressof @__workgroup_mem__kernel_multi_wg_1
+ // CHECK: spirv.Return
+ gpu.func @kernel_multi_wg(
+ %arg0: memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ workgroup(
+ %wg0: memref<128xf32, #spirv.storage_class<Workgroup>>,
+ %wg1: memref<64xi32, #spirv.storage_class<Workgroup>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [128, 1, 1]>} {
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %0 = "op"() : () -> (memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ %cst = arith.constant 1 : index
+ gpu.launch_func @kernels::@kernel_multi_wg
+ blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+ args(%0 : memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ return
+ }
+}
+
+// -----
+
+// Test gpu.func with workgroup attribution and barrier.
+module attributes {gpu.container_module} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__kernel_wg_barrier_0 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32>)>, Workgroup>
+ // CHECK-LABEL: spirv.func @kernel_wg_barrier
+ // CHECK: spirv.mlir.addressof @__workgroup_mem__kernel_wg_barrier_0
+ // CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
+ // CHECK: spirv.Return
+ gpu.func @kernel_wg_barrier(
+ %arg0: memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ workgroup(%wg: memref<256xf32, #spirv.storage_class<Workgroup>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>} {
+ gpu.barrier
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %0 = "op"() : () -> (memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ %cst = arith.constant 1 : index
+ gpu.launch_func @kernels::@kernel_wg_barrier
+ blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+ args(%0 : memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ return
+ }
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/workgroup-memory.mlir b/mlir/test/Conversion/GPUToSPIRV/workgroup-memory.mlir
new file mode 100644
index 0000000000000..9bff6a6268691
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/workgroup-memory.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+// Test workgroup memory load/store through gpu.func workgroup attributions.
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ // CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
+ gpu.module @kernels {
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__load_store_workgroup_0 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32>)>, Workgroup>
+ // CHECK-LABEL: spirv.func @load_store_workgroup
+ // CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>
+ gpu.func @load_store_workgroup(
+ %arg0: memref<256xf32, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ workgroup(%smem: memref<256xf32, #spirv.storage_class<Workgroup>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [256, 1, 1]>} {
+ // CHECK: %[[WG:.*]] = spirv.mlir.addressof @__workgroup_mem__load_store_workgroup_0
+ // CHECK: spirv.AccessChain %[[ARG0]][{{%.*}}, {{%.*}}]
+ // CHECK: spirv.Load "StorageBuffer"
+ // CHECK: spirv.AccessChain %[[WG]][{{%.*}}, {{%.*}}]
+ // CHECK: spirv.Store "Workgroup"
+ // CHECK: spirv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
+ // CHECK: spirv.AccessChain %[[WG]][{{%.*}}, {{%.*}}]
+ // CHECK: spirv.Load "Workgroup"
+ // CHECK: spirv.AccessChain %[[ARG1]][{{%.*}}, {{%.*}}]
+ // CHECK: spirv.Store "StorageBuffer"
+ %c0 = arith.constant 0 : index
+ %val = memref.load %arg0[%c0] : memref<256xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %val, %smem[%c0] : memref<256xf32, #spirv.storage_class<Workgroup>>
+ gpu.barrier
+ %val2 = memref.load %smem[%c0] : memref<256xf32, #spirv.storage_class<Workgroup>>
+ memref.store %val2, %arg1[%c0] : memref<256xf32, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ %0 = "op"() : () -> (memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ %1 = "op"() : () -> (memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ %cst = arith.constant 1 : index
+ gpu.launch_func @kernels::@load_store_workgroup
+ blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+ args(%0 : memref<256xf32, #spirv.storage_class<StorageBuffer>>,
+ %1 : memref<256xf32, #spirv.storage_class<StorageBuffer>>)
+ return
+ }
+}
+
+// -----
+
+// Test with 2D workgroup memory.
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ gpu.module @kernels {
+ // 16 x 16 = 256 elements
+ // CHECK-DAG: spirv.GlobalVariable @__workgroup_mem__kernel_2d_wg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x f32>)>, Workgroup>
+ // CHECK-LABEL: spirv.func @kernel_2d_wg
+ gpu.func @kernel_2d_wg(
+ %arg0: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>)
+ workgroup(%smem: memref<16x16xf32, #spirv.storage_class<Workgroup>>)
+ kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 16, 1]>} {
+ %c0 = arith.constant 0 : index
+ %val = memref.load %arg0[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %val, %smem[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<Workgroup>>
+ gpu.return
+ }
+ }
+}
More information about the Mlir-commits mailing list