[Mlir-commits] [mlir] 6d231fb - [mlir] MemRefToSPIRV propagate alignment attributes from MemRef ops. (#151723)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 7 09:18:26 PDT 2025
Author: Erick Ochoa Lopez
Date: 2025-08-07T12:18:23-04:00
New Revision: 6d231fbb05417a77e8787f625fd14e1a30e27a5b
URL: https://github.com/llvm/llvm-project/commit/6d231fbb05417a77e8787f625fd14e1a30e27a5b
DIFF: https://github.com/llvm/llvm-project/commit/6d231fbb05417a77e8787f625fd14e1a30e27a5b.diff
LOG: [mlir] MemRefToSPIRV propagate alignment attributes from MemRef ops. (#151723)
This patchset:
* propagates alignment attributes from memref operations into the SPIR-V
dialect;
* fixes an error in the logic that previously propagated alignment
attributes but did not add other MemoryAccess attributes;
* adds a failure condition for the case where the alignment attribute
from the memref dialect (64-bit wide) does not fit in SPIR-V's alignment
attribute (specified to be 32-bit wide).
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c7ecd8334da42..2e00b42f4a56d 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
+#include <limits>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -475,7 +476,12 @@ struct MemoryRequirements {
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
-calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
+calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
+ uint64_t preferredAlignment) {
+ if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
+ return failure();
+ }
+
MLIRContext *ctx = accessedPtr.getContext();
auto memoryAccess = spirv::MemoryAccess::None;
@@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
}
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
- if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+ bool mayOmitAlignment =
+ !preferredAlignment &&
+ ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
+ if (mayOmitAlignment) {
if (memoryAccess == spirv::MemoryAccess::None) {
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
}
@@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
}
// PhysicalStorageBuffers require the `Aligned` attribute.
+ // Other storage types may show an `Aligned` attribute.
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
if (!pointeeType)
return failure();
@@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
- auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
return MemoryRequirements{memAccessAttr, alignment};
}
@@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
"Must be called on either memref::LoadOp or memref::StoreOp");
- Operation *memrefAccessOp = loadOrStoreOp.getOperation();
- auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
- spirv::attributeName<spirv::MemoryAccess>());
- auto memrefAlignment =
- memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
- if (memrefMemAccess && memrefAlignment)
- return MemoryRequirements{memrefMemAccess, memrefAlignment};
-
return calculateMemoryRequirements(accessedPtr,
- loadOrStoreOp.getNontemporal());
+ loadOrStoreOp.getNontemporal(),
+ loadOrStoreOp.getAlignment().value_or(0));
}
LogicalResult
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2a7be0be7477a..e6321e99693ac 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -85,6 +85,28 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
return %0: i1
}
+// CHECK-LABEL: func @load_aligned
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+ // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+ return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_nontemporal
+func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+ // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned|Nontemporal", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+ return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_psb
+func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
+ // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" {{.*}} ["Aligned", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
+ return %0: i1
+}
+
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
// CHECK-SAME: %[[IDX:.+]]: index
More information about the Mlir-commits
mailing list