[Mlir-commits] [mlir] [mlir][spirv] Allow complex element types in memref allocation checks (PR #175836)
llvmlistbot at llvm.org
Tue Jan 13 13:03:11 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Eric Feng (efric)
<details>
<summary>Changes</summary>
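Support for complex element types in SPIR-V was introduced in 97f3bb73a29a566e99e33ae4338c2c3d9957e561, and memref type conversion was updated accordingly. However, the element type check in isAllocationSupported, which the memref allocation patterns in the SPIR-V lowering use to decide whether an allocation can be converted, was not updated to recognize complex element types, so allocations with a complex element type were rejected. This patch fixes that inconsistency by unwrapping a complex element type to its underlying int or float type before the check.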
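To make the effect of the change concrete, here is a minimal, self-contained C++ sketch of the element type check after this patch. It deliberately does not use MLIR: Kind, Type, scalar, wrap, and isElementTypeSupported are hypothetical stand-ins that model only enough of the type hierarchy to mirror the order of the checks (peel off a vector, then a complex, then require an int or float scalar).

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical, heavily simplified model of MLIR types; this is NOT the MLIR
// C++ API, only enough structure to mirror the patched element type check.
enum class Kind { Integer, Float, Index, Complex, Vector };

struct Type {
  Kind kind;
  std::shared_ptr<const Type> element; // non-null only for Complex and Vector
};

using TypePtr = std::shared_ptr<const Type>;

// Builds a scalar type such as i32, f32, or index.
TypePtr scalar(Kind kind) {
  assert(kind != Kind::Complex && kind != Kind::Vector);
  return std::make_shared<Type>(Type{kind, nullptr});
}

// Builds a wrapper type such as complex<f32> or vector<4xf32>.
TypePtr wrap(Kind kind, TypePtr element) {
  assert((kind == Kind::Complex || kind == Kind::Vector) && element);
  return std::make_shared<Type>(Type{kind, std::move(element)});
}

// Mirrors the patched check: unwrap one vector layer, then one complex layer,
// then accept only integer or float scalars. Index is rejected, matching the
// behavior of Type::isIntOrFloat() in MLIR.
bool isElementTypeSupported(TypePtr type) {
  if (type->kind == Kind::Vector)
    type = type->element;
  if (type->kind == Kind::Complex)
    type = type->element;
  return type->kind == Kind::Integer || type->kind == Kind::Float;
}
```

In this model, complex<f32> is now accepted alongside f32 and vector-of-int types, while index is still rejected, because MLIR's isIntOrFloat() does not treat index as an integer type.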
Fixes: https://github.com/iree-org/iree/issues/23117
---
Full diff: https://github.com/llvm/llvm-project/pull/175836.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+4-2)
- (modified) mlir/test/Conversion/MemRefToSPIRV/alloc.mlir (+26)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..42e082f69e475 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -134,14 +134,16 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
     return false;
   }
 
-  // Currently only support static shape and int or float or vector of int or
-  // float element type.
+  // Currently only support static shape and int or float, complex of int or
+  // float, or vector of int or float element type.
   if (!type.hasStaticShape())
     return false;
 
   Type elementType = type.getElementType();
   if (auto vecType = dyn_cast<VectorType>(elementType))
     elementType = vecType.getElementType();
+  if (auto compType = dyn_cast<ComplexType>(elementType))
+    elementType = compType.getElementType();
   return elementType.isIntOrFloat();
 }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index f4fbfdf01196a..bb71557faef41 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -195,3 +195,29 @@ module attributes {
 // Zero-sized allocations are not handled yet. Just make sure we do not crash.
 // CHECK-LABEL: func @zero_size()
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  }
+{
+  func.func @alloc_dealloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloc() : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    return
+  }
+}
+
+// CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>
+// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_complex_f32
+// CHECK-NOT: memref.alloc
+// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
+// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : vector<2xf32>
+// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : vector<2xf32>
+// CHECK-NOT: memref.dealloc
``````````
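The new test expects the 4x5 workgroup buffer to become a single !spirv.array<20 x vector<2xf32>>, with each complex<f32> element represented as vector<2xf32>. The C++ sketch below reproduces that arithmetic under stated assumptions: numElements, convertedElementType, arrayType, and linearIndex are hypothetical helpers, not MLIR APIs. The row-major linearization of the load/store indices assumes the default identity layout; the test's CHECK lines do not show how the AccessChain index is computed. The sketch also covers only the array part of the expected type, not the enclosing !spirv.struct and Workgroup pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helpers illustrating the shape arithmetic behind the expected
// !spirv.array<20 x vector<2xf32>>; these are not MLIR APIs.

// Total number of elements in a static shape, e.g. {4, 5} -> 20.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (int64_t dim : shape)
    count *= dim;
  return count;
}

// Spelling of a converted complex<T> element: a two-lane vector of T, as in
// the test's CHECK line (vector<2xf32> for complex<f32>).
std::string convertedElementType(const std::string &complexElem) {
  return "vector<2x" + complexElem + ">";
}

// Spells the flattened array type, e.g. "!spirv.array<20 x vector<2xf32>>".
std::string arrayType(const std::vector<int64_t> &shape,
                      const std::string &complexElem) {
  return "!spirv.array<" + std::to_string(numElements(shape)) + " x " +
         convertedElementType(complexElem) + ">";
}

// Row-major (identity layout) linear offset of an index tuple, e.g. for
// shape {4, 5}, indices {i, j} map to i * 5 + j. Assumed, not from the test.
int64_t linearIndex(const std::vector<int64_t> &shape,
                    const std::vector<int64_t> &indices) {
  assert(shape.size() == indices.size());
  int64_t offset = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(indices[i] >= 0 && indices[i] < shape[i]);
    offset = offset * shape[i] + indices[i];
  }
  return offset;
}
```

For example, arrayType({4, 5}, "f32") spells the array type from the GlobalVariable CHECK line, and under the row-major assumption the element at [1, 2] lives at offset 7 of the 20-element array.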
</details>
https://github.com/llvm/llvm-project/pull/175836