[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)
Artem Tyurin
llvmlistbot at llvm.org
Sat Feb 17 08:14:55 PST 2024
https://github.com/agentcooper created https://github.com/llvm/llvm-project/pull/82119
Fixes #77156.
From c058eb4a564ea9f2deabe26e3691064e0a58b4a3 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 17 Feb 2024 17:13:58 +0100
Subject: [PATCH] [mlir][spirv] Retain nontemporal attribute when converting
memref load/store
---
.../lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 ++++++++++-
.../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 14 ++++++++++++++
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");
+ auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+ if (nontemporalAttr && nontemporalAttr.getValue()) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+ spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
}
+
+// -----
+
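+// Check that {nontemporal = true} on memref.load/memref.store is lowered to the "Nontemporal" memory access on spirv.Load/spirv.Store.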
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+ return
+ }
+}
More information about the Mlir-commits mailing list