[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Apr 1 16:42:29 PDT 2025
================
@@ -903,6 +903,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct GlobalLoadLDSOpLowering
+ : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+ GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+ size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+ if (elemSizeInBits % 8 != 0)
+ return op.emitOpError("element size must be a multiple of 8");
+
+ // TODO: instead of only transfering one element per thread, we could
+ // augment it to transfer multiple elements per thread by issuing multiple
+ // `global_load_lds` instructions.
+ auto loadWidth = elemSizeInBits / 8;
+
+ const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+ if (chipset < GlobalLoadEnabled)
+ return op.emitOpError("chipset not supported");
+
+ // Currently only 1, 2, and 4 byte loads are supported.
+ if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+ return op.emitOpError("chipset unsupported element size");
+
+ // Return pair of {base pointer, linearized index}.
+ auto getBasePtrAndLinearizedIndex =
+ [&](Value memref, MemRefType memrefType,
+ ValueRange indices) -> std::optional<std::pair<Value, Value>> {
+ MemRefDescriptor memRefDescriptor(memref);
+ int64_t offset = 0;
+ SmallVector<int64_t, 5> strides;
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
+ return {};
+ return std::make_pair(
+ memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType),
+ getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
+ };
+
+ auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+ op.getSrcIndices());
+ if (!optSrcBuffer)
+ return op.emitOpError("failed to flatten source memref indices");
+ auto optDstBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+ op.getDstIndices());
+ if (!optDstBuffer)
+ return op.emitOpError("failed to flatten destination memref indices");
+
+ Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+ Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+ Value srcPtr = rewriter.create<LLVM::GEPOp>(
----------------
krzysz00 wrote:
This should follow the `memref.load` and/or `memref.store` lowerings for computing the pointer.
https://github.com/llvm/llvm-project/pull/133498
More information about the Mlir-commits mailing list