[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)

Artem Tyurin llvmlistbot at llvm.org
Tue Feb 27 23:38:51 PST 2024


https://github.com/agentcooper updated https://github.com/llvm/llvm-project/pull/82119

>From c058eb4a564ea9f2deabe26e3691064e0a58b4a3 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 17 Feb 2024 17:13:58 +0100
Subject: [PATCH 1/5] [mlir][spirv] Retain nontemporal attribute when
 converting memref load/store

---
 .../lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 ++++++++++-
 .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir  | 14 ++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
   assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
          "Bad op type");
 
+  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+  if (nontemporalAttr && nontemporalAttr.getValue()) {
+    return std::pair{
+        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+                                     spirv::MemoryAccess::Nontemporal),
+        IntegerAttr{}};
+  }
+
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
   auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 }
 
 }
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+    return
+  }
+}

>From 7f46ec8d9794c67b844c8e0a0e2d64b9d924996a Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Mon, 26 Feb 2024 11:32:19 +0100
Subject: [PATCH 2/5] Set nontemporal only if alignment is not set

---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 77 ++++++++++---------
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 19 ++++-
 2 files changed, 57 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 34318c612b4691..ec9d1261fef058 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -445,15 +445,23 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-using AlignmentRequirements =
+using MemoryRequirements =
     FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
 
 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
 /// any.
-static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+static MemoryRequirements calculateMemoryRequirements(Value accessedPtr,
+                                                      bool isNontemporal) {
+  MLIRContext *ctx = accessedPtr.getContext();
   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
-  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+    if (isNontemporal) {
+      return std::pair{
+          spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Nontemporal),
+          IntegerAttr{}};
+    }
     return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+  }
 
   // PhysicalStorageBuffers require the `Aligned` attribute.
   auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
@@ -465,7 +473,6 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
   if (!sizeInBytes.has_value())
     return failure();
 
-  MLIRContext *ctx = accessedPtr.getContext();
   auto memAccessAttr =
       spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
   auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
@@ -475,28 +482,22 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
 /// Given an accessed SPIR-V pointer and the original memref load/store
 /// `memAccess` op, calculates the alignment requirements, if any. Takes into
 /// account the alignment attributes applied to the load/store op.
-static AlignmentRequirements
-calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
-  assert(memrefAccessOp);
-  assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
-         "Bad op type");
-
-  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
-  if (nontemporalAttr && nontemporalAttr.getValue()) {
-    return std::pair{
-        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
-                                     spirv::MemoryAccess::Nontemporal),
-        IntegerAttr{}};
-  }
-
-  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+template <class LoadOrStoreOp>
+static MemoryRequirements
+calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
+  static_assert(
+      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
+      "Must be called on either memref::LoadOp or memref::StoreOp");
+
+  Operation *op = loadOrStoreOp.getOperation();
+  auto memrefMemAccess = op->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
-  auto memrefAlignment =
-      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+  auto memrefAlignment = op->getAttrOfType<IntegerAttr>("alignment");
   if (memrefMemAccess && memrefAlignment)
     return std::pair{memrefMemAccess, memrefAlignment};
 
-  return calculateRequiredAlignment(accessedPtr);
+  return calculateMemoryRequirements(accessedPtr,
+                                     loadOrStoreOp.getNontemporal());
 }
 
 LogicalResult
@@ -546,13 +547,13 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // If the rewritten load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    AlignmentRequirements alignmentRequirements =
-        calculateRequiredAlignment(accessChain, loadOp);
-    if (failed(alignmentRequirements))
+    MemoryRequirements memoryRequirements =
+        calculateMemoryRequirements(accessChain, loadOp);
+    if (failed(memoryRequirements))
       return rewriter.notifyMatchFailure(
-          loadOp, "failed to determine alignment requirements");
+          loadOp, "failed to determine memory requirements");
 
-    auto [memoryAccess, alignment] = *alignmentRequirements;
+    auto [memoryAccess, alignment] = *memoryRequirements;
     Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                    memoryAccess, alignment);
     if (isBool)
@@ -576,13 +577,13 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   assert(accessChainOp.getIndices().size() == 2);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
-  AlignmentRequirements alignmentRequirements =
-      calculateRequiredAlignment(adjustedPtr, loadOp);
-  if (failed(alignmentRequirements))
+  MemoryRequirements memoryRequirements =
+      calculateMemoryRequirements(adjustedPtr, loadOp);
+  if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
-        loadOp, "failed to determine alignment requirements");
+        loadOp, "failed to determine memory requirements");
 
-  auto [memoryAccess, alignment] = *alignmentRequirements;
+  auto [memoryAccess, alignment] = *memoryRequirements;
   Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                    memoryAccess, alignment);
 
@@ -631,8 +632,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment =
-      calculateRequiredAlignment(loadPtr, loadOp);
+  MemoryRequirements requiredAlignment =
+      calculateMemoryRequirements(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
@@ -698,8 +699,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
-    AlignmentRequirements requiredAlignment =
-        calculateRequiredAlignment(accessChain);
+    MemoryRequirements requiredAlignment =
+        calculateMemoryRequirements(accessChain, storeOp);
     if (failed(requiredAlignment))
       return rewriter.notifyMatchFailure(
           storeOp, "failed to determine alignment requirements");
@@ -856,8 +857,8 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (!storePtr)
     return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
 
-  AlignmentRequirements requiredAlignment =
-      calculateRequiredAlignment(storePtr, storeOp);
+  MemoryRequirements requiredAlignment =
+      calculateMemoryRequirements(storePtr, storeOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         storeOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e03b7bdf357dd5..bef833967d1c6f 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -436,7 +436,15 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 
 // Check nontemporal attribute
 
-module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
+    Shader,
+    PhysicalStorageBufferAddresses
+  ], [
+    SPV_KHR_storage_buffer_storage_class,
+    SPV_KHR_physical_storage_buffer
+  ]>, #spirv.resource_limits<>>
+} {
   func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
     %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
 //       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
@@ -444,4 +452,13 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader
 //       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
     return
   }
+
+  // Nontemporal attribute is ignored in case of alignment
+  func.func @load_nontemporal_ignored(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+//       CHECK:  spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+//       CHECK:  spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+    return
+  }
 }

>From c6e2d8025e6b6b6d263370696eac76766ab56266 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Tue, 27 Feb 2024 18:53:55 +0100
Subject: [PATCH 3/5] Allow Aligned and Nontemporal flags to coexist

---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 75 ++++++++++---------
 .../MemRefToSPIRV/memref-to-spirv.mlir        |  4 +-
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index ec9d1261fef058..213674d6d40948 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -445,22 +445,29 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-using MemoryRequirements =
-    FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+struct MemoryRequirements {
+  spirv::MemoryAccessAttr memoryAccess;
+  IntegerAttr alignment;
+};
 
 /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
 /// any.
-static MemoryRequirements calculateMemoryRequirements(Value accessedPtr,
-                                                      bool isNontemporal) {
+static FailureOr<MemoryRequirements>
+calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
   MLIRContext *ctx = accessedPtr.getContext();
+
+  auto memoryAccess = spirv::MemoryAccess::None;
+  if (isNontemporal) {
+    memoryAccess = memoryAccess | spirv::MemoryAccess::Nontemporal;
+  }
+
   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
   if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
-    if (isNontemporal) {
-      return std::pair{
-          spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Nontemporal),
-          IntegerAttr{}};
+    if (memoryAccess == spirv::MemoryAccess::None) {
+      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
     }
-    return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
+                              IntegerAttr{}};
   }
 
   // PhysicalStorageBuffers require the `Aligned` attribute.
@@ -473,28 +480,29 @@ static MemoryRequirements calculateMemoryRequirements(Value accessedPtr,
   if (!sizeInBytes.has_value())
     return failure();
 
-  auto memAccessAttr =
-      spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
   auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
-  return std::pair{memAccessAttr, alignment};
+  return MemoryRequirements{memAccessAttr, alignment};
 }
 
 /// Given an accessed SPIR-V pointer and the original memref load/store
 /// `memAccess` op, calculates the alignment requirements, if any. Takes into
 /// account the alignment attributes applied to the load/store op.
 template <class LoadOrStoreOp>
-static MemoryRequirements
+static FailureOr<MemoryRequirements>
 calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
   static_assert(
       llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
       "Must be called on either memref::LoadOp or memref::StoreOp");
 
-  Operation *op = loadOrStoreOp.getOperation();
-  auto memrefMemAccess = op->getAttrOfType<spirv::MemoryAccessAttr>(
+  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
+  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
-  auto memrefAlignment = op->getAttrOfType<IntegerAttr>("alignment");
+  auto memrefAlignment =
+      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
   if (memrefMemAccess && memrefAlignment)
-    return std::pair{memrefMemAccess, memrefAlignment};
+    return MemoryRequirements{memrefMemAccess, memrefAlignment};
 
   return calculateMemoryRequirements(accessedPtr,
                                      loadOrStoreOp.getNontemporal());
@@ -547,8 +555,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // If the rewritten load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    MemoryRequirements memoryRequirements =
-        calculateMemoryRequirements(accessChain, loadOp);
+    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
     if (failed(memoryRequirements))
       return rewriter.notifyMatchFailure(
           loadOp, "failed to determine memory requirements");
@@ -577,8 +584,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   assert(accessChainOp.getIndices().size() == 2);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
-  MemoryRequirements memoryRequirements =
-      calculateMemoryRequirements(adjustedPtr, loadOp);
+  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
   if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine memory requirements");
@@ -632,14 +638,13 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  MemoryRequirements requiredAlignment =
-      calculateMemoryRequirements(loadPtr, loadOp);
-  if (failed(requiredAlignment))
+  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
+  if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
 
-  auto [memAccessAttr, alignment] = *requiredAlignment;
-  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+  auto [memoryAccess, alignment] = *memoryRequirements;
+  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                              alignment);
   return success();
 }
@@ -699,18 +704,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
-    MemoryRequirements requiredAlignment =
-        calculateMemoryRequirements(accessChain, storeOp);
-    if (failed(requiredAlignment))
+    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
+    if (failed(memoryRequirements))
       return rewriter.notifyMatchFailure(
           storeOp, "failed to determine alignment requirements");
 
-    auto [memAccessAttr, alignment] = *requiredAlignment;
+    auto [memoryAccess, alignment] = *memoryRequirements;
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
-                                                memAccessAttr, alignment);
+                                                memoryAccess, alignment);
     return success();
   }
 
@@ -857,15 +861,14 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (!storePtr)
     return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
 
-  MemoryRequirements requiredAlignment =
-      calculateMemoryRequirements(storePtr, storeOp);
-  if (failed(requiredAlignment))
+  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
+  if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
         storeOp, "failed to determine alignment requirements");
 
-  auto [memAccessAttr, alignment] = *requiredAlignment;
+  auto [memoryAccess, alignment] = *memoryRequirements;
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(
-      storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
+      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
   return success();
 }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index bef833967d1c6f..21056a54fdc9dd 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -456,9 +456,9 @@ module attributes {
   // Nontemporal attribute is ignored in case of alignment
   func.func @load_nontemporal_ignored(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
     %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
-//       CHECK:  spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+//       CHECK:  spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned|Nontemporal", 4] : f32
     memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
-//       CHECK:  spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+//       CHECK:  spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned|Nontemporal", 4] : f32
     return
   }
 }

>From 4f13c4dfe380a3da6b1146dd2e165dd5a6ce707c Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Tue, 27 Feb 2024 18:58:40 +0100
Subject: [PATCH 4/5] Fix test description and unify failure messages

---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp     | 6 +++---
 mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 3 +--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 213674d6d40948..302d38d20d408f 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -641,7 +641,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
   if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
-        loadOp, "failed to determine alignment requirements");
+        loadOp, "failed to determine memory requirements");
 
   auto [memoryAccess, alignment] = *memoryRequirements;
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
@@ -707,7 +707,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
     auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
     if (failed(memoryRequirements))
       return rewriter.notifyMatchFailure(
-          storeOp, "failed to determine alignment requirements");
+          storeOp, "failed to determine memory requirements");
 
     auto [memoryAccess, alignment] = *memoryRequirements;
     Value storeVal = adaptor.getValue();
@@ -864,7 +864,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
   if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
-        storeOp, "failed to determine alignment requirements");
+        storeOp, "failed to determine memory requirements");
 
   auto [memoryAccess, alignment] = *memoryRequirements;
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 21056a54fdc9dd..feb6d4e924015f 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -453,8 +453,7 @@ module attributes {
     return
   }
 
-  // Nontemporal attribute is ignored in case of alignment
-  func.func @load_nontemporal_ignored(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  func.func @load_nontemporal_aligned(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
     %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
 //       CHECK:  spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned|Nontemporal", 4] : f32
     memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>

>From 286e3f8dfe125b4e7d2d4d204cb8ee291970a587 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Wed, 28 Feb 2024 08:38:36 +0100
Subject: [PATCH 5/5] Simplify Nontemporal flag initialization (review feedback)

---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 302d38d20d408f..0acb2142f3f68a 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -458,7 +458,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
 
   auto memoryAccess = spirv::MemoryAccess::None;
   if (isNontemporal) {
-    memoryAccess = memoryAccess | spirv::MemoryAccess::Nontemporal;
+    memoryAccess = spirv::MemoryAccess::Nontemporal;
   }
 
   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());



More information about the Mlir-commits mailing list