[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)

Artem Tyurin llvmlistbot at llvm.org
Sat Feb 17 08:14:55 PST 2024


https://github.com/agentcooper created https://github.com/llvm/llvm-project/pull/82119

Fixes #77156.

From c058eb4a564ea9f2deabe26e3691064e0a58b4a3 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 17 Feb 2024 17:13:58 +0100
Subject: [PATCH] [mlir][spirv] Retain nontemporal attribute when converting
 memref load/store

---
 .../lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 ++++++++++-
 .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir  | 14 ++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
   assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
          "Bad op type");
 
+  auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+  if (nontemporalAttr && nontemporalAttr.getValue()) {
+    return std::pair{
+        spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+                                     spirv::MemoryAccess::Nontemporal),
+        IntegerAttr{}};
+  }
+
   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
       spirv::attributeName<spirv::MemoryAccess>());
   auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!loadPtr)
     return failure();
 
-  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(loadPtr, loadOp);
   if (failed(requiredAlignment))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 }
 
 }
+
+// -----
+
+// Check that the nontemporal attribute is retained on both memref.load and memref.store.
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+    memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:  spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+    return
+  }
+}



More information about the Mlir-commits mailing list