[Mlir-commits] [mlir] 64c4e52 - [mlir][SPIRV] Add alignment calculation to support `PhysicalStorageBuffer` with vector types (#187698)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 20 23:16:02 PDT 2026
Author: Artem Gindinson
Date: 2026-03-21T07:15:57+01:00
New Revision: 64c4e529a9560b77a049fe80f635d9a29a7b515e
URL: https://github.com/llvm/llvm-project/commit/64c4e529a9560b77a049fe80f635d9a29a7b515e
DIFF: https://github.com/llvm/llvm-project/commit/64c4e529a9560b77a049fe80f635d9a29a7b515e.diff
LOG: [mlir][SPIRV] Add alignment calculation to support `PhysicalStorageBuffer` with vector types (#187698)
This makes it possible to lower `memref.load`/`store` operations on
`PhysicalStorageBuffer`-typed resources whose underlying type is a
vector type. This improves support for the `PhysicalStorageBuffer`
capability in pipelines that use the Vector dialect for distribution.
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 565dee6f27589..3b2aab6c9a824 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -630,12 +630,21 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
// PhysicalStorageBuffers require the `Aligned` attribute.
// Other storage types may show an `Aligned` attribute.
- auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
- if (!pointeeType)
- return failure();
+ std::optional<int64_t> sizeInBytes;
+ Type rawPointeeType = ptrType.getPointeeType();
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
+ // For scalar types, the alignment is determined by their size.
+ sizeInBytes = scalarType.getSizeInBytes();
+ } else if (auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
+ // For vector element types, the alignment should equal the total size of
+ // the vector.
+ if (auto scalarElem =
+ dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
+ if (auto elemSize = scalarElem.getSizeInBytes())
+ sizeInBytes = *elemSize * vecType.getNumElements();
+ }
+ }
- // For scalar types, the alignment is determined by their size.
- std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
if (!sizeInBytes.has_value())
return failure();
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index ab3c8b7397e1a..931dd43be33c3 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -185,6 +185,37 @@ func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<Physi
return
}
+// CHECK-LABEL: @load_store_vec4f16_physical
+func.func @load_store_vec4f16_physical(%arg0: memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+ // Alignment = 4 elements * 2 bytes = 8 bytes
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 8] : vector<4xf16>
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 8] : vector<4xf16>
+ %0 = memref.load %arg0[%i] : memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[%i] : memref<5xvector<4xf16>, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_vec4f32_physical
+func.func @load_store_vec4f32_physical(%arg0: memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+ // Alignment = 4 elements * 4 bytes = 16 bytes
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 16] : vector<4xf32>
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 16] : vector<4xf32>
+ %0 = memref.load %arg0[%i] : memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[%i] : memref<8xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_vec4f32_dynamic_physical
+// CHECK-SAME: (%[[ARG0:.+]]: memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>,
+func.func @load_store_vec4f32_dynamic_physical(%arg0: memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>, %i: index) {
+ // CHECK: builtin.unrealized_conversion_cast %[[ARG0]] {{.*}} to !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16>
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 16] : vector<4xf32>
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 16] : vector<4xf32>
+ %0 = memref.load %arg0[%i] : memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[%i] : memref<?xvector<4xf32>, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
} // end module
// -----
More information about the Mlir-commits mailing list