[Mlir-commits] [mlir] 8fd0bce - [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (#80243)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 1 15:33:29 PST 2024


Author: Jakub Kuderski
Date: 2024-02-01T18:33:26-05:00
New Revision: 8fd0bce43c4c8334bcb31d214a32260914f59515

URL: https://github.com/llvm/llvm-project/commit/8fd0bce43c4c8334bcb31d214a32260914f59515
DIFF: https://github.com/llvm/llvm-project/commit/8fd0bce43c4c8334bcb31d214a32260914f59515.diff

LOG: [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (#80243)

The SPIR-V spec requires that memory accesses to
`PhysicalStorageBuffer`s are annotated with appropriate alignment
attributes [1]. Calculate these based on memref alignment attributes or
scalar type sizes.

[1] Otherwise spirv-val complains:
```
[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
  %48 = OpLoad %float %47
```

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..57d8e894a24b0 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/Debug.h"
+#include <cassert>
 #include <optional>
 
 #define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+using AlignmentRequirements = // {access, alignment}; null pair = none needed
+    FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Computes the {MemoryAccess, alignment} attrs for an access through
+/// `accessedPtr`: nulls if not PhysicalStorageBuffer, else `Aligned` + size.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+    return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+  // PhysicalStorageBuffers require `Aligned`; only scalar pointees supported.
+  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+  if (!pointeeType)
+    return failure();
+
+  // Alignment = scalar size in bytes; fail if the size is unknown.
+  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+  if (!sizeInBytes.has_value())
+    return failure();
+
+  MLIRContext *ctx = accessedPtr.getContext();
+  auto memAccessAttr =
+      spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  return std::pair{memAccessAttr, alignment};
+}
+
+/// Like the overload above, but first uses `memrefAccessOp`'s own memory-access
+/// and `alignment` attrs, only when BOTH are set. NOTE(review): a lone
+/// memory-access attr (e.g. Volatile) is dropped here; confirm that's intended.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+  assert(memrefAccessOp);
+  assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+         "Bad op type");
+
+  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+      spirv::attributeName<spirv::MemoryAccess>());
+  auto memrefAlignment =
+      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+  if (memrefMemAccess && memrefAlignment)
+    return std::pair{memrefMemAccess, memrefAlignment};
+
+  return calculateRequiredAlignment(accessedPtr);
+}
+
 LogicalResult
 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // If the rewritten load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+    AlignmentRequirements alignmentRequirements =
+        calculateRequiredAlignment(accessChain, loadOp);
+    if (failed(alignmentRequirements))
+      return rewriter.notifyMatchFailure(
+          loadOp, "failed to determine alignment requirements");
+
+    auto [memoryAccess, alignment] = *alignmentRequirements;
+    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+                                                   memoryAccess, alignment);
     if (isBool)
       loadVal = castIntNToBool(loc, loadVal, rewriter);
     rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   assert(accessChainOp.getIndices().size() == 2);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
-  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
-      loc, dstType, adjustedPtr,
-      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
-          spirv::attributeName<spirv::MemoryAccess>()),
-      loadOp->getAttrOfType<IntegerAttr>("alignment"));
+  AlignmentRequirements alignmentRequirements =
+      calculateRequiredAlignment(adjustedPtr, loadOp);
+  if (failed(alignmentRequirements))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memoryAccess, alignment] = *alignmentRequirements;
+  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+                                                   memoryAccess, alignment);
 
   // Shift the bits to the rightmost.
   // ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
-  auto loadPtr = spirv::getElementPtr(
+  Value loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
 
   if (!loadPtr)
     return failure();
 
-  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  if (failed(requiredAlignment))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memAccessAttr, alignment] = *requiredAlignment;
+  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+                                             alignment);
   return success();
 }
 
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
+    AlignmentRequirements requiredAlignment =
+        calculateRequiredAlignment(accessChain);
+    if (failed(requiredAlignment))
+      return rewriter.notifyMatchFailure(
+          storeOp, "failed to determine alignment requirements");
+
+    auto [memAccessAttr, alignment] = *requiredAlignment;
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+                                                memAccessAttr, alignment);
     return success();
   }
 
@@ -768,8 +847,15 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (!storePtr)
     return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
 
-  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
-                                              adaptor.getValue());
+  AlignmentRequirements requiredAlignment =
+      calculateRequiredAlignment(storePtr, storeOp);
+  if (failed(requiredAlignment))
+    return rewriter.notifyMatchFailure(
+        storeOp, "failed to determine alignment requirements");
+
+  auto [memAccessAttr, alignment] = *requiredAlignment;
+  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+      storeOp, storePtr, adaptor.getValue(), memAccessAttr, alignment);
   return success();
 }
 

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
 
 module attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0,
+    #spirv.vce<v1.5,
       [
         Shader, Int8, Int16, Int64, Float16, Float64,
         StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
-        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+        PhysicalStorageBufferAddresses
       ],
-      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+      #spirv.resource_limits<>>
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
   return
 }
 
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) { // alignment = element byte size
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) { // i1 is stored as i8 (bool-num-bits=8)
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+  %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list