[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 17 08:15:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Artem Tyurin (agentcooper)

<details>
<summary>Changes</summary>

Fixes #77156.

---
Full diff: https://github.com/llvm/llvm-project/pull/82119.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+10-1) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
   assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
          "Bad op type");
 
+  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+  if (nontemporalAttr && nontemporalAttr.getValue()) {
+    return std::pair{
+        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+                                     spirv::MemoryAccess::Nontemporal),
+        IntegerAttr{}};
+  }
+
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
   auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 }
 
 }
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+    return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/82119


More information about the Mlir-commits mailing list