[Mlir-commits] [mlir] [mlir] MemRefToSPIRV propagate alignment attributes from MemRef ops. (PR #151723)

Erick Ochoa Lopez llvmlistbot at llvm.org
Fri Aug 1 12:09:31 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/151723

>From f4664aa5d2723babef4d11ebba6616a019ff3ed0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 1 Aug 2025 11:51:48 -0400
Subject: [PATCH 1/3] [mlir] Use memref's alignment attribute directly.

---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7a705336bf11c..0411589ed583d 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -511,8 +511,7 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
   Operation *memrefAccessOp = loadOrStoreOp.getOperation();
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
-  auto memrefAlignment =
-      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+  auto memrefAlignment = loadOrStoreOp.getAlignmentAttr();
   if (memrefMemAccess && memrefAlignment)
     return MemoryRequirements{memrefMemAccess, memrefAlignment};
 

>From 1c3455ce071a6a0de24bc32147bcc144827f8c87 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 1 Aug 2025 12:17:59 -0400
Subject: [PATCH 2/3] [mlir] MemRefToSPIRV propagate alignment attribute.

---
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp    |  9 ++++++---
 .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0411589ed583d..9b8e39fdd0335 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -465,7 +465,8 @@ struct MemoryRequirements {
 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
 /// any.
 static FailureOr<MemoryRequirements>
-calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
+calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
+                            uint64_t preferredAlignment) {
   MLIRContext *ctx = accessedPtr.getContext();
 
   auto memoryAccess = spirv::MemoryAccess::None;
@@ -494,7 +495,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
 
   memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
   auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
-  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
   return MemoryRequirements{memAccessAttr, alignment};
 }
 
@@ -516,7 +518,8 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
     return MemoryRequirements{memrefMemAccess, memrefAlignment};
 
   return calculateMemoryRequirements(accessedPtr,
-                                     loadOrStoreOp.getNontemporal());
+                                     loadOrStoreOp.getNontemporal(),
+                                     loadOrStoreOp.getAlignment().value_or(0));
 }
 
 LogicalResult
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index d0ddac8cd801c..a00a6e0cbfe8a 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -85,6 +85,21 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
   return %0: i1
 }
 
+// CHECK-LABEL: func @load_aligned
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, PhysicalStorageBuffer>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
+  // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
+  %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
 // CHECK-LABEL: func @store_i1
 //  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
 //  CHECK-SAME: %[[IDX:.+]]: index

>From 526787f4fb8a87830db1061e8f781ce73e9f74f1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 1 Aug 2025 15:04:50 -0400
Subject: [PATCH 3/3] [mlir] Fix calculateMemoryRequirements in MemRefToSPIRV.

There was an early return in calculateMemoryRequirements
that looked explicitly for alignment and only set the alignment attribute.
However, this was not correct for the following reasons:

* Alignment was set only if both the alignment and the memory_access
  attributes were present on the memref operation; the case where only
  the alignment attribute was present was not handled.
* When the alignment and memory_access attributes were both present,
  the memory_access attribute was not updated to include `Aligned` if
  it was not already marked as aligned.
* When the alignment and memory_access attributes were both present,
  other memory requirements (e.g., non_temporal) were not added as
  attributes.
---
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp  | 13 +++++--------
 .../MemRefToSPIRV/memref-to-spirv.mlir          | 17 ++++++++++++++++-
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 9b8e39fdd0335..2204cacf959ce 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -475,7 +475,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
   }
 
   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
-  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+  bool mayOmitAlignment =
+      !preferredAlignment &&
+      ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
+  if (mayOmitAlignment) {
     if (memoryAccess == spirv::MemoryAccess::None) {
       return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
     }
@@ -484,6 +487,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
   }
 
   // PhysicalStorageBuffers require the `Aligned` attribute.
+  // Other storage classes may also carry an `Aligned` attribute.
   auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
   if (!pointeeType)
     return failure();
@@ -510,13 +514,6 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
       llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
       "Must be called on either memref::LoadOp or memref::StoreOp");
 
-  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
-  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
-      spirv::attributeName<spirv::MemoryAccess>());
-  auto memrefAlignment = loadOrStoreOp.getAlignmentAttr();
-  if (memrefMemAccess && memrefAlignment)
-    return MemoryRequirements{memrefMemAccess, memrefAlignment};
-
   return calculateMemoryRequirements(accessedPtr,
                                      loadOrStoreOp.getNontemporal(),
                                      loadOrStoreOp.getAlignment().value_or(0));
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a00a6e0cbfe8a..95c7349476230 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -86,8 +86,23 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
 }
 
 // CHECK-LABEL: func @load_aligned
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
+  // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
+  %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_psb
 //  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %[[IDX:.+]]: index)
-func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
+func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
   // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, PhysicalStorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32



More information about the Mlir-commits mailing list