[Mlir-commits] [mlir] [mlir] MemRefToSPIRV propagate alignment attributes from MemRef ops. (PR #151723)

llvmlistbot at llvm.org
Fri Aug 1 12:39:10 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

This PR propagates the `alignment` attribute from memref load and store operations onto the SPIR-V ops generated by the MemRefToSPIRV conversion.
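
For illustration, a minimal before/after sketch drawn from the new test cases in the diff below (the value names `%src`, `%i`, and `%addr` are placeholders):

```mlir
// Input: a memref.load annotated with an alignment attribute.
%0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>

// After MemRefToSPIRV conversion, the alignment is carried over as the
// `Aligned` memory-access operand on the generated spirv.Load.
%val = spirv.Load "StorageBuffer" %addr ["Aligned", 32] : i8
```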

---
Full diff: https://github.com/llvm/llvm-project/pull/151723.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+17-12) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+45) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7a705336bf11c..e730998f153b0 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Visitors.h"
 #include <cassert>
+#include <limits>
 #include <optional>
 
 #define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -465,7 +466,13 @@ struct MemoryRequirements {
 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
 /// any.
 static FailureOr<MemoryRequirements>
-calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
+calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
+                            uint64_t preferredAlignment) {
+
+  if (std::numeric_limits<uint32_t>::max() < preferredAlignment) {
+    return failure();
+  }
+
   MLIRContext *ctx = accessedPtr.getContext();
 
   auto memoryAccess = spirv::MemoryAccess::None;
@@ -474,7 +481,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
   }
 
   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
-  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+  bool mayOmitAlignment =
+      !preferredAlignment &&
+      ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
+  if (mayOmitAlignment) {
     if (memoryAccess == spirv::MemoryAccess::None) {
       return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
     }
@@ -483,6 +493,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
   }
 
   // PhysicalStorageBuffers require the `Aligned` attribute.
+  // Other storage types may show an `Aligned` attribute.
   auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
   if (!pointeeType)
     return failure();
@@ -494,7 +505,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
 
   memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
   auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
-  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
   return MemoryRequirements{memAccessAttr, alignment};
 }
 
@@ -508,16 +520,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
       llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
       "Must be called on either memref::LoadOp or memref::StoreOp");
 
-  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
-  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
-      spirv::attributeName<spirv::MemoryAccess>());
-  auto memrefAlignment =
-      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
-  if (memrefMemAccess && memrefAlignment)
-    return MemoryRequirements{memrefMemAccess, memrefAlignment};
-
   return calculateMemoryRequirements(accessedPtr,
-                                     loadOrStoreOp.getNontemporal());
+                                     loadOrStoreOp.getNontemporal(),
+                                     loadOrStoreOp.getAlignment().value_or(0));
 }
 
 LogicalResult
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index d0ddac8cd801c..7c765f70136bb 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -85,6 +85,51 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
   return %0: i1
 }
 
+// CHECK-LABEL: func @load_aligned
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
+  // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
+  %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_nontemporal
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
+  // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned|Nontemporal", 32] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
+  %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_psb
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, PhysicalStorageBuffer>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
+  // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
+  %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
 // CHECK-LABEL: func @store_i1
 //  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
 //  CHECK-SAME: %[[IDX:.+]]: index

``````````



https://github.com/llvm/llvm-project/pull/151723

