[Mlir-commits] [mlir] [mlir][spirv][memref] Calculate alignment for `PhysicalStorageBuffer`s (PR #80243)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Feb 1 14:14:11 PST 2024
================
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//
+using AlignmentRequirements =
+ FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+ auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+ if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+ return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+ // PhysicalStorageBuffers require the `Aligned` attribute.
+ auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+ if (!pointeeType)
+ return failure();
+
+ // For scalar types, the alignment is determined by their size.
+ std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+ if (!sizeInBytes.has_value())
+ return failure();
+
+ MLIRContext *ctx = accessedPtr.getContext();
+ auto memAccessAttr =
+ spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+ assert(memrefAccessOp);
+ assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+ "Bad op type");
----------------
qedawkins wrote:
Could this be replaced with a template and a static assert? All the users seem to be pattern rewrites that know the op type.
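A minimal sketch of what the suggestion could look like, assuming the two overloads keep the control flow shown in the hunk. Everything here is a hypothetical stand-in, not the MLIR API: `PointerInfo` replaces the SPIR-V pointer `Value`, `LoadOp`/`StoreOp`/`AllocOp` replace the memref ops, and `AlignmentRequirements` models `FailureOr<std::pair<...>>` as a struct. The runtime `assert(isa<...>)` becomes a `static_assert` on the template parameter. In-tree code could likely use `llvm::is_one_of<OpT, memref::LoadOp, memref::StoreOp>` instead of `std::is_same_v`. The op's own alignment attributes, which the real function also consults, are deliberately left out.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>

// Hypothetical stand-ins for the MLIR/SPIR-V types in the patch; not the real API.
enum class StorageClass { StorageBuffer, PhysicalStorageBuffer };

struct PointerInfo {
  StorageClass storageClass;
  // Byte size of a scalar pointee; std::nullopt models a non-scalar pointee
  // (the dyn_cast<spirv::ScalarType> in the patch failing) or an unknown size.
  std::optional<int64_t> scalarSizeInBytes;
};

struct LoadOp {};
struct StoreOp {};
struct AllocOp {};  // Neither a load nor a store: rejected at compile time.

// Models FailureOr<pair<MemoryAccessAttr, IntegerAttr>>:
// failed == true stands for failure(); alignment == nullopt for the empty pair.
struct AlignmentRequirements {
  bool failed = false;
  std::optional<int64_t> alignment;
};

// Pointer-only variant, mirroring the first overload in the diff.
AlignmentRequirements calculateRequiredAlignment(const PointerInfo &ptr) {
  if (ptr.storageClass != StorageClass::PhysicalStorageBuffer)
    return {};  // No `Aligned` memory access needed.
  if (!ptr.scalarSizeInBytes)
    return {true, std::nullopt};  // Non-scalar or unsized pointee: fail.
  // Scalar pointee: the alignment equals its size in bytes.
  return {false, *ptr.scalarSizeInBytes};
}

// Op-typed variant: the op type is checked when the template is instantiated,
// so a wrong op type is a compile error instead of a runtime assert.
template <typename OpT>
AlignmentRequirements calculateRequiredAlignment(const PointerInfo &ptr,
                                                 const OpT &) {
  static_assert(std::is_same_v<OpT, LoadOp> || std::is_same_v<OpT, StoreOp>,
                "expected a load or store op");
  // The patch also folds in the op's alignment attributes; omitted here.
  return calculateRequiredAlignment(ptr);
}
```

Because each pattern rewrite already knows its op type, a call such as `calculateRequiredAlignment(ptr, StoreOp{})` compiles, while `calculateRequiredAlignment(ptr, AllocOp{})` would trip the `static_assert`. The trade-off is one instantiation per op type, which is negligible for two ops.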
https://github.com/llvm/llvm-project/pull/80243