[Mlir-commits] [mlir] [mlir][SPIRV] Add alignment calculation to support `PhysicalStorageBuffer` with vector types (PR #187698)

Artem Gindinson llvmlistbot at llvm.org
Fri Mar 20 07:00:54 PDT 2026


https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/187698

This allows lowering `memref.load`/`store` operations on `PhysicalStorageBuffer`-typed resources whose underlying element type is a vector type. This improves support for the `PhysicalStorageBuffer` capability in pipelines that use the Vector dialect for distribution.

>From 8fded1cf69b8815914ea7fd62eaa473bd2118505 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 20 Mar 2026 10:27:17 +0000
Subject: [PATCH] [mlir][SPIRV] Add alignment calculation to support
 `PhysicalStorageBuffer` with vector types

This allows lowering `memref.load`/`store` operations on `PhysicalStorageBuffer`-typed resources
whose underlying element type is a vector type. This improves support for the `PhysicalStorageBuffer`
capability in pipelines that use the Vector dialect for distribution.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 19 +++++++++---
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 31 +++++++++++++++++++
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 565dee6f27589..3b2aab6c9a824 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -630,12 +630,21 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
 
   // PhysicalStorageBuffers require the `Aligned` attribute.
   // Other storage types may show an `Aligned` attribute.
-  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
-  if (!pointeeType)
-    return failure();
+  std::optional<int64_t> sizeInBytes;
+  Type rawPointeeType = ptrType.getPointeeType();
+  if (auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
+    // For scalar types, the alignment is determined by their size.
+    sizeInBytes = scalarType.getSizeInBytes();
+  } else if (auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
+    // For vector element types, the alignment should equal the total size of
+    // the vector.
+    if (auto scalarElem =
+            dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
+      if (auto elemSize = scalarElem.getSizeInBytes())
+        sizeInBytes = *elemSize * vecType.getNumElements();
+    }
+  }
 
-  // For scalar types, the alignment is determined by their size.
-  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
   if (!sizeInBytes.has_value())
     return failure();
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index ab3c8b7397e1a..931dd43be33c3 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -185,6 +185,37 @@ func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<Physi
   return
 }
 
+// CHECK-LABEL: @load_store_vec4f16_physical
+func.func @load_store_vec4f16_physical(%arg0: memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+  // Alignment = 4 elements * 2 bytes = 8 bytes
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 8] : vector<4xf16>
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 8] : vector<4xf16>
+  %0 = memref.load %arg0[%i] : memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[%i] : memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_vec4f32_physical
+func.func @load_store_vec4f32_physical(%arg0: memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+  // Alignment = 4 elements * 4 bytes = 16 bytes
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 16] : vector<4xf32>
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 16] : vector<4xf32>
+  %0 = memref.load %arg0[%i] : memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[%i] : memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_vec4f32_dynamic_physical
+//  CHECK-SAME: (%[[ARG0:.+]]: memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>,
+func.func @load_store_vec4f32_dynamic_physical(%arg0: memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+  //     CHECK: builtin.unrealized_conversion_cast %[[ARG0]] {{.*}} to !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16>
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 16] : vector<4xf32>
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 16] : vector<4xf32>
+  %0 = memref.load %arg0[%i] : memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[%i] : memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
 } // end module
 
 // -----



More information about the Mlir-commits mailing list