[Mlir-commits] [mlir] d11ef6a - [mlir][spirv] Allow complex element types in memref allocation checks (#175836)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 13:51:43 PST 2026
Author: Eric Feng
Date: 2026-01-13T16:51:39-05:00
New Revision: d11ef6a2b86a70ac7f20c7b164fbe6faab0dd8a1
URL: https://github.com/llvm/llvm-project/commit/d11ef6a2b86a70ac7f20c7b164fbe6faab0dd8a1
DIFF: https://github.com/llvm/llvm-project/commit/d11ef6a2b86a70ac7f20c7b164fbe6faab0dd8a1.diff
LOG: [mlir][spirv] Allow complex element types in memref allocation checks (#175836)
Support for complex types in SPIR-V was introduced in
97f3bb73a29a566e99e33ae4338c2c3d9957e561, and memref type conversion was
updated accordingly to include them. However, the element type precheck
used during memref alloc/dealloc pattern matching in the SPIR-V lowering
was not updated to recognize complex element types. This patch resolves
this inconsistency.
Fixes: https://github.com/iree-org/iree/issues/23117
---------
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..42e082f69e475 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -134,14 +134,16 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
return false;
}
- // Currently only support static shape and int or float or vector of int or
- // float element type.
+ // Currently only support static shape and int or float, complex of int or
+ // float, or vector of int or float element type.
if (!type.hasStaticShape())
return false;
Type elementType = type.getElementType();
if (auto vecType = dyn_cast<VectorType>(elementType))
elementType = vecType.getElementType();
+ if (auto compType = dyn_cast<ComplexType>(elementType))
+ elementType = compType.getElementType();
return elementType.isIntOrFloat();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index f4fbfdf01196a..bb71557faef41 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -195,3 +195,29 @@ module attributes {
// Zero-sized allocations are not handled yet. Just make sure we do not crash.
// CHECK-LABEL: func @zero_size()
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ }
+{
+ func.func @alloc_dealloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
+ %0 = memref.alloc() : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+ %1 = memref.load %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+ memref.store %1, %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+ memref.dealloc %0 : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+ return
+ }
+}
+
+// CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>
+// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_complex_f32
+// CHECK-NOT: memref.alloc
+// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
+// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : vector<2xf32>
+// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : vector<2xf32>
+// CHECK-NOT: memref.dealloc
More information about the Mlir-commits
mailing list