[Mlir-commits] [mlir] 8fd0bce - [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (#80243)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 1 15:33:29 PST 2024
Author: Jakub Kuderski
Date: 2024-02-01T18:33:26-05:00
New Revision: 8fd0bce43c4c8334bcb31d214a32260914f59515
URL: https://github.com/llvm/llvm-project/commit/8fd0bce43c4c8334bcb31d214a32260914f59515
DIFF: https://github.com/llvm/llvm-project/commit/8fd0bce43c4c8334bcb31d214a32260914f59515.diff
LOG: [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (#80243)
The SPIR-V spec requires that memory accesses to
`PhysicalStorageBuffer`s use the `Aligned` memory operand with an explicit
alignment [1]. Calculate the alignment based on memref alignment attributes
or scalar type sizes.
[1] Otherwise spirv-val complains:
```
[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
%48 = OpLoad %float %47
```
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..57d8e894a24b0 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
+#include <cassert>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//
+using AlignmentRequirements =
+ FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements:
+/// none for non-PhysicalStorageBuffer pointers, else the scalar pointee size.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+ auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+ if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+ return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+ // PhysicalStorageBuffers require `Aligned`; fail for non-scalar pointees.
+ auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+ if (!pointeeType)
+ return failure();
+
+ // Use the scalar's byte size as the alignment; fail if it is unknown.
+ std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+ if (!sizeInBytes.has_value())
+ return failure();
+
+ MLIRContext *ctx = accessedPtr.getContext();
+ auto memAccessAttr =
+ spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store op
+/// `memrefAccessOp`, calculates the alignment requirements, if any. If the op
+/// carries both memory access and alignment attributes, those are used as-is.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+ assert(memrefAccessOp);
+ assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+ "Bad op type");
+
+ auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+ spirv::attributeName<spirv::MemoryAccess>());
+ auto memrefAlignment =
+ memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+ if (memrefMemAccess && memrefAlignment)
+ return std::pair{memrefMemAccess, memrefAlignment};
+ // NOTE(review): a memory access or alignment attr set alone is ignored.
+ return calculateRequiredAlignment(accessedPtr);
+}
+
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(accessChain, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+ // Both attrs may be null when no alignment is required.
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(
- loc, dstType, adjustedPtr,
- loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
- spirv::attributeName<spirv::MemoryAccess>()),
- loadOp->getAttrOfType<IntegerAttr>("alignment"));
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(adjustedPtr, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+ // Without memref.load attrs, use adjustedPtr's (dstType) alignment.
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
- auto loadPtr = spirv::getElementPtr(
+ Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+ AlignmentRequirements required = calculateRequiredAlignment(loadPtr, loadOp);
+ if (failed(required))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+ // Attrs may be null; memref.load's own attrs win when both are set.
+ auto [memAccessAttr, alignment] = *required;
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+ alignment);
return success();
}
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(accessChain, storeOp);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to determine alignment requirements");
+ // Attrs may be null; memref.store's own attrs win when both are set.
+ auto [memAccessAttr, alignment] = *requiredAlignment;
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+ memAccessAttr, alignment);
return success();
}
@@ -768,8 +847,15 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (!storePtr)
return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
- adaptor.getValue());
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(storePtr, storeOp);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to determine alignment requirements");
+ // Attrs may be null; memref.store's own attrs win when both are set.
+ auto [memAccessAttr, alignment] = *requiredAlignment;
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
return success();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
module attributes {
spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.0,
+ #spirv.vce<v1.5,
[
Shader, Int8, Int16, Int64, Float16, Float64,
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
- StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+ StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+ PhysicalStorageBufferAddresses
],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+ #spirv.resource_limits<>>
} {
// CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,59 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
return
}
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+ %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+ %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+ %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// Explicit memory access + alignment attrs on memref.store take precedence.
+// CHECK-LABEL: @store_f32_physical_explicit_alignment
+func.func @store_f32_physical_explicit_alignment(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>, %arg1: f32) {
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 16] : f32
+ memref.store %arg1, %arg0[] {alignment = 16 : i32, memory_access = #spirv.memory_access<Aligned>} : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
} // end module
// -----
More information about the Mlir-commits
mailing list