[Mlir-commits] [mlir] [mlir][spirv] Retain nontemporal attribute when converting memref load/store (PR #82119)
Artem Tyurin
llvmlistbot at llvm.org
Mon Feb 26 02:32:32 PST 2024
https://github.com/agentcooper updated https://github.com/llvm/llvm-project/pull/82119
>From c058eb4a564ea9f2deabe26e3691064e0a58b4a3 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Sat, 17 Feb 2024 17:13:58 +0100
Subject: [PATCH 1/2] [mlir][spirv] Retain nontemporal attribute when
converting memref load/store
---
.../lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 ++++++++++-
.../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 14 ++++++++++++++
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b0e..34318c612b4691 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -481,6 +481,14 @@ calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
"Bad op type");
+ auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
+ if (nontemporalAttr && nontemporalAttr.getValue()) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
+ spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
+
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
auto memrefAlignment =
@@ -623,7 +631,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index aa05fd9bc8ca89..e03b7bdf357dd5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -431,3 +431,17 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
}
+
+// -----
+
+// Check nontemporal attribute
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
+// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
+ return
+ }
+}
>From 7f46ec8d9794c67b844c8e0a0e2d64b9d924996a Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Mon, 26 Feb 2024 11:32:19 +0100
Subject: [PATCH 2/2] Set Nontemporal only when no alignment is required
---
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 77 ++++++++++---------
.../MemRefToSPIRV/memref-to-spirv.mlir | 19 ++++-
2 files changed, 57 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 34318c612b4691..ec9d1261fef058 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -445,15 +445,23 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//
-using AlignmentRequirements =
+using MemoryRequirements =
FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
-static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+static MemoryRequirements calculateMemoryRequirements(Value accessedPtr,
+ bool isNontemporal) {
+ MLIRContext *ctx = accessedPtr.getContext();
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
- if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+ if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+ if (isNontemporal) {
+ return std::pair{
+ spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Nontemporal),
+ IntegerAttr{}};
+ }
return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+ }
// PhysicalStorageBuffers require the `Aligned` attribute.
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
@@ -465,7 +473,6 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
if (!sizeInBytes.has_value())
return failure();
- MLIRContext *ctx = accessedPtr.getContext();
auto memAccessAttr =
spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
@@ -475,28 +482,22 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
-static AlignmentRequirements
-calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
- assert(memrefAccessOp);
- assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
- "Bad op type");
-
- auto nontemporalAttr = memrefAccessOp->getAttrOfType<BoolAttr>("nontemporal");
- if (nontemporalAttr && nontemporalAttr.getValue()) {
- return std::pair{
- spirv::MemoryAccessAttr::get(accessedPtr.getContext(),
- spirv::MemoryAccess::Nontemporal),
- IntegerAttr{}};
- }
-
- auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+template <class LoadOrStoreOp>
+static MemoryRequirements
+calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
+ static_assert(
+ llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
+ "Must be called on either memref::LoadOp or memref::StoreOp");
+
+ Operation *op = loadOrStoreOp.getOperation();
+ auto memrefMemAccess = op->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>());
- auto memrefAlignment =
- memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+ auto memrefAlignment = op->getAttrOfType<IntegerAttr>("alignment");
if (memrefMemAccess && memrefAlignment)
return std::pair{memrefMemAccess, memrefAlignment};
- return calculateRequiredAlignment(accessedPtr);
+ return calculateMemoryRequirements(accessedPtr,
+ loadOrStoreOp.getNontemporal());
}
LogicalResult
@@ -546,13 +547,13 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
- AlignmentRequirements alignmentRequirements =
- calculateRequiredAlignment(accessChain, loadOp);
- if (failed(alignmentRequirements))
+ MemoryRequirements memoryRequirements =
+ calculateMemoryRequirements(accessChain, loadOp);
+ if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
- loadOp, "failed to determine alignment requirements");
+ loadOp, "failed to determine memory requirements");
- auto [memoryAccess, alignment] = *alignmentRequirements;
+ auto [memoryAccess, alignment] = *memoryRequirements;
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
memoryAccess, alignment);
if (isBool)
@@ -576,13 +577,13 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
- AlignmentRequirements alignmentRequirements =
- calculateRequiredAlignment(adjustedPtr, loadOp);
- if (failed(alignmentRequirements))
+ MemoryRequirements memoryRequirements =
+ calculateMemoryRequirements(adjustedPtr, loadOp);
+ if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
- loadOp, "failed to determine alignment requirements");
+ loadOp, "failed to determine memory requirements");
- auto [memoryAccess, alignment] = *alignmentRequirements;
+ auto [memoryAccess, alignment] = *memoryRequirements;
Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
memoryAccess, alignment);
@@ -631,8 +632,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!loadPtr)
return failure();
- AlignmentRequirements requiredAlignment =
- calculateRequiredAlignment(loadPtr, loadOp);
+ MemoryRequirements requiredAlignment =
+ calculateMemoryRequirements(loadPtr, loadOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine alignment requirements");
@@ -698,8 +699,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
- AlignmentRequirements requiredAlignment =
- calculateRequiredAlignment(accessChain);
+ MemoryRequirements requiredAlignment =
+ calculateMemoryRequirements(accessChain, storeOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine alignment requirements");
@@ -856,8 +857,8 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!storePtr)
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
- AlignmentRequirements requiredAlignment =
- calculateRequiredAlignment(storePtr, storeOp);
+ MemoryRequirements requiredAlignment =
+ calculateMemoryRequirements(storePtr, storeOp);
if (failed(requiredAlignment))
return rewriter.notifyMatchFailure(
storeOp, "failed to determine alignment requirements");
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e03b7bdf357dd5..bef833967d1c6f 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -436,7 +436,15 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
// Check nontemporal attribute
-module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
+ Shader,
+ PhysicalStorageBufferAddresses
+ ], [
+ SPV_KHR_storage_buffer_storage_class,
+ SPV_KHR_physical_storage_buffer
+ ]>, #spirv.resource_limits<>>
+} {
func.func @load_nontemporal(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
%0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Load "StorageBuffer" %{{.+}} ["Nontemporal"] : f32
@@ -444,4 +452,13 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader
// CHECK: spirv.Store "StorageBuffer" %{{.+}}, %{{.+}} ["Nontemporal"] : f32
return
}
+
+ // Nontemporal attribute is ignored in case of alignment
+ func.func @load_nontemporal_ignored(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ %0 = memref.load %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+// CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+ memref.store %0, %arg0[] {nontemporal = true} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+// CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+ return
+ }
}
More information about the Mlir-commits
mailing list