[Mlir-commits] [mlir] [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (PR #80243)

Jakub Kuderski llvmlistbot at llvm.org
Thu Feb 1 15:31:29 PST 2024


================
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+using AlignmentRequirements =
+    FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+    return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+  // PhysicalStorageBuffers require the `Aligned` attribute.
+  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+  if (!pointeeType)
+    return failure();
+
+  // For scalar types, the alignment is determined by their size.
+  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+  if (!sizeInBytes.has_value())
+    return failure();
+
+  MLIRContext *ctx = accessedPtr.getContext();
+  auto memAccessAttr =
+      spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+  assert(memrefAccessOp);
+  assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+         "Bad op type");
----------------
kuhar wrote:

We could, but I prefer it this way: less code duplication and better code autocompletion. Because it's an internal function, I don't think the benefit of a static_assert over a runtime assert justifies the cost.
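The reply compares a compile-time check with the runtime `assert` used in the quoted `calculateRequiredAlignment(Value, Operation *)` overload. The reviewer's original suggestion is not quoted here. Assuming it was a templated overload guarded by `static_assert`, the sketch below contrasts the two approaches in plain C++17. `Operation`, `OpKind`, `LoadOp`, `StoreOp`, `requiredAlignmentRuntime`, and `requiredAlignmentStatic` are hypothetical stand-ins, not MLIR's real classes or functions. The returned "alignment" is simply the accessed element size.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for MLIR types; NOT the real mlir::Operation API.
enum class OpKind { Load, Store, Dealloc };

// Untyped op handle, loosely mimicking `mlir::Operation *`.
struct Operation {
  OpKind kind;
  int64_t accessSizeInBytes; // Size of the accessed scalar element.
};

// Typed wrappers, loosely mimicking `memref::LoadOp` / `memref::StoreOp`.
struct LoadOp { Operation *op; };
struct StoreOp { Operation *op; };

// Option 1 (the shape kept in the PR): one non-template function that takes
// an untyped op and checks its kind with a runtime assert. The check is
// compiled out when NDEBUG is defined.
int64_t requiredAlignmentRuntime(Operation *op) {
  assert(op);
  assert((op->kind == OpKind::Load || op->kind == OpKind::Store) &&
         "Bad op type");
  return op->accessSizeInBytes;
}

// Option 2 (hypothetical alternative): a template restricted to the two op
// types with static_assert, so misuse is rejected at compile time. The
// template is instantiated once per accepted op type.
template <typename OpT>
int64_t requiredAlignmentStatic(OpT typedOp) {
  static_assert(std::is_same_v<OpT, LoadOp> || std::is_same_v<OpT, StoreOp>,
                "Bad op type");
  return typedOp.op->accessSizeInBytes;
}
```

Calling `requiredAlignmentStatic` with any other type fails to compile. Passing a non-load/store `Operation *` to `requiredAlignmentRuntime` only trips the assert in builds without `NDEBUG`. Inside the template body, `OpT` is a dependent type, so editors generally cannot offer member completion on `typedOp`. This is plausibly what the "better code autocompletion" remark refers to.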

https://github.com/llvm/llvm-project/pull/80243

