[Mlir-commits] [mlir] [mlir][spirv] Allow complex element types in memref allocation checks (PR #175836)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 13 13:03:11 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Eric Feng (efric)

<details>
<summary>Changes</summary>

Support for complex element types in SPIR-V was introduced in 97f3bb73a29a566e99e33ae4338c2c3d9957e561, and memref type conversion was updated accordingly. However, the element-type check used when matching memref allocation patterns in the SPIR-V lowering (`isAllocationSupported`) was not updated to recognize complex element types, so allocations such as `memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>` were not converted. This patch resolves that inconsistency.
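
The element-type part of the check can be modeled outside MLIR. The following is a minimal, hypothetical C++ sketch: `ToyType`, `isElementTypeSupported`, and the `handleComplex` flag are invented for illustration and are not MLIR's API. The flag switches between the behavior before and after the patch, and the unwrap order (vector first, then complex) mirrors the diff.

```cpp
#include <cassert>

// Hypothetical toy model of MLIR element types; NOT the real MLIR API.
struct ToyType {
  enum Kind { Int, Index, Float, Vector, Complex } kind;
  const ToyType *element;  // Wrapped type for Vector/Complex, nullptr otherwise.
};

// Stand-in for Type::isIntOrFloat(): index types are not int-or-float.
bool isIntOrFloat(const ToyType *t) {
  return t->kind == ToyType::Int || t->kind == ToyType::Float;
}

// Mirrors the element-type tail of isAllocationSupported(): peel one vector
// level, then (only when handleComplex is set, i.e. after the patch) one
// complex level, then require an int or float scalar.
bool isElementTypeSupported(const ToyType *elementType, bool handleComplex) {
  if (elementType->kind == ToyType::Vector)
    elementType = elementType->element;
  if (handleComplex && elementType->kind == ToyType::Complex)
    elementType = elementType->element;
  return isIntOrFloat(elementType);
}

// Sample element types.
const ToyType f32{ToyType::Float, nullptr};
const ToyType indexTy{ToyType::Index, nullptr};
const ToyType complexF32{ToyType::Complex, &f32};
const ToyType vector4F32{ToyType::Vector, &f32};

// Before the patch complex<f32> is rejected; after it, the f32 payload is checked.
void checkToyModel() {
  assert(!isElementTypeSupported(&complexF32, /*handleComplex=*/false));
  assert(isElementTypeSupported(&complexF32, /*handleComplex=*/true));
  assert(isElementTypeSupported(&vector4F32, /*handleComplex=*/true));
  assert(!isElementTypeSupported(&indexTy, /*handleComplex=*/true));
}
```

With `handleComplex` off, `complex<f32>` fails the int-or-float test, just as the unpatched check did; with it on, the complex wrapper is peeled and its `f32` element is accepted.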

Fixes: https://github.com/iree-org/iree/issues/23117

---
Full diff: https://github.com/llvm/llvm-project/pull/175836.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+4-2) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/alloc.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..42e082f69e475 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -134,14 +134,16 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
     return false;
   }
 
-  // Currently only support static shape and int or float or vector of int or
-  // float element type.
+  // Currently only support static shape and int or float, complex of int or
+  // float, or vector of int or float element type.
   if (!type.hasStaticShape())
     return false;
 
   Type elementType = type.getElementType();
   if (auto vecType = dyn_cast<VectorType>(elementType))
     elementType = vecType.getElementType();
+  if (auto compType = dyn_cast<ComplexType>(elementType))
+    elementType = compType.getElementType();
   return elementType.isIntOrFloat();
 }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index f4fbfdf01196a..bb71557faef41 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -195,3 +195,29 @@ module attributes {
 
 // Zero-sized allocations are not handled yet. Just make sure we do not crash.
 // CHECK-LABEL: func @zero_size()
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  }
+{
+  func.func @alloc_dealloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloc() : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    return
+  }
+}
+
+//       CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>
+// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_complex_f32
+//   CHECK-NOT:   memref.alloc
+//       CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
+//       CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+//       CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : vector<2xf32>
+//       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+//       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : vector<2xf32>
+//   CHECK-NOT:   memref.dealloc

``````````
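
The `CHECK` line for the global variable encodes a small calculation: the static 4x5 shape is flattened to 20 elements, and each `complex<f32>` element becomes `vector<2xf32>`. The helpers below are hypothetical (not part of MLIR) and only reproduce that spelling, assuming the complex-to-`vector<2xT>` mapping shown in the test.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: flatten a static memref shape into a total element count.
int64_t flattenedElementCount(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (int64_t dim : shape)
    count *= dim;
  return count;
}

// Hypothetical helper: spell the Workgroup global's type for a memref of
// complex<scalar>, assuming complex<T> lowers to vector<2xT> as in the test.
std::string loweredComplexWorkgroupType(const std::vector<int64_t> &shape,
                                        const std::string &scalar) {
  return "!spirv.ptr<!spirv.struct<(!spirv.array<" +
         std::to_string(flattenedElementCount(shape)) + " x vector<2x" +
         scalar + ">>)>, Workgroup>";
}

// The 4x5xcomplex<f32> allocation from the test maps to a 20-element array.
void checkWorkgroupSpelling() {
  assert(flattenedElementCount({4, 5}) == 20);
  assert(loweredComplexWorkgroupType({4, 5}, "f32") ==
         "!spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>");
}
```

The loads and stores in the test likewise operate on `vector<2xf32>`, since each array element holds one complex value's real and imaginary parts.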

</details>


https://github.com/llvm/llvm-project/pull/175836

