[Mlir-commits] [mlir] [mlir][spirv] Allow complex element types in memref allocation checks (PR #175836)

Eric Feng llvmlistbot at llvm.org
Tue Jan 13 13:02:41 PST 2026


https://github.com/efric created https://github.com/llvm/llvm-project/pull/175836

Support for complex element types in SPIR-V was introduced in 97f3bb73a29a566e99e33ae4338c2c3d9957e561, and memref type conversion was updated accordingly. However, the element type check used during memref allocation pattern matching in the SPIR-V lowering was not updated to recognize complex element types. This patch resolves this inconsistency.

Fixes: https://github.com/iree-org/iree/issues/23117

>From a00e8e17e237d5d57a082bea5ec314089fa61e0a Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 13 Jan 2026 12:44:35 -0800
Subject: [PATCH 1/3] fix spirv issue

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           |  6 +++--
 mlir/test/Conversion/MemRefToSPIRV/alloc.mlir | 26 +++++++++++++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index a90dcc8cc3ef1..42e082f69e475 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -134,14 +134,16 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
     return false;
   }
 
-  // Currently only support static shape and int or float or vector of int or
-  // float element type.
+  // Currently only support static shape and int or float, complex of int or
+  // float, or vector of int or float element type.
   if (!type.hasStaticShape())
     return false;
 
   Type elementType = type.getElementType();
   if (auto vecType = dyn_cast<VectorType>(elementType))
     elementType = vecType.getElementType();
+  if (auto compType = dyn_cast<ComplexType>(elementType))
+    elementType = compType.getElementType();
   return elementType.isIntOrFloat();
 }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index f4fbfdf01196a..06d53c5a8fba2 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -195,3 +195,29 @@ module attributes {
 
 // Zero-sized allocations are not handled yet. Just make sure we do not crash.
 // CHECK-LABEL: func @zero_size()
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  }
+{
+  func.func @alloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloc() : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
+    return
+  }
+}
+
+//       CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>
+// CHECK-LABEL: func @alloc_workgroup_mem_complex_f32
+//   CHECK-NOT:   memref.alloc
+//       CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
+//       CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+//       CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : vector<2xf32>
+//       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+//       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : vector<2xf32>
+//   CHECK-NOT:   memref.dealloc
\ No newline at end of file

>From 3aef23dc119bc8c2d5f37f7e56e86faa8e865adf Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 13 Jan 2026 12:50:24 -0800
Subject: [PATCH 2/3] newline

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/test/Conversion/MemRefToSPIRV/alloc.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 06d53c5a8fba2..7b041211c8be4 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -220,4 +220,4 @@ module attributes {
 //       CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : vector<2xf32>
 //       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
 //       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : vector<2xf32>
-//   CHECK-NOT:   memref.dealloc
\ No newline at end of file
+//   CHECK-NOT:   memref.dealloc

>From a9d7426d2ee222d12e027467b1a10298d72176cf Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 13 Jan 2026 12:51:26 -0800
Subject: [PATCH 3/3] nit

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/test/Conversion/MemRefToSPIRV/alloc.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 7b041211c8be4..bb71557faef41 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -203,7 +203,7 @@ module attributes {
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
   }
 {
-  func.func @alloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
+  func.func @alloc_dealloc_workgroup_mem_complex_f32(%arg0 : index, %arg1 : index) {
     %0 = memref.alloc() : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
     %1 = memref.load %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
     memref.store %1, %0[%arg0, %arg1] : memref<4x5xcomplex<f32>, #spirv.storage_class<Workgroup>>
@@ -213,7 +213,7 @@ module attributes {
 }
 
 //       CHECK: spirv.GlobalVariable @[[$VAR:.+]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x vector<2xf32>>)>, Workgroup>
-// CHECK-LABEL: func @alloc_workgroup_mem_complex_f32
+// CHECK-LABEL: func @alloc_dealloc_workgroup_mem_complex_f32
 //   CHECK-NOT:   memref.alloc
 //       CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[$VAR]]
 //       CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]



More information about the Mlir-commits mailing list