[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)
llvmlistbot at llvm.org
Sat Feb 17 08:15:23 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Artem Tyurin (agentcooper)
<details>
<summary>Changes</summary>
Fixes #77156.
---
Full diff: https://github.com/llvm/llvm-project/pull/82119.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+10-1)
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+14)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");
+ auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+ if (nontemporalAttr && nontemporalAttr.getValue()) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+ spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
}
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+ return
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/82119
More information about the Mlir-commits mailing list