[Mlir-commits] [mlir] [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (PR #145395)
Alan Li
llvmlistbot at llvm.org
Tue Jun 24 10:26:37 PDT 2025
================
@@ -1100,6 +1100,60 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct TransposeLoadOpLowering
+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+ TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset != kGfx950)
+ return op.emitOpError("Non-gfx950 chipset not supported");
+
+ Location loc = op.getLoc();
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ auto resultType = cast<VectorType>(op.getResult().getType());
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()));
+
+ size_t numElements = resultType.getNumElements();
+ size_t elementTypeSize =
+ resultType.getElementType().getIntOrFloatBitWidth();
+
+ switch (elementTypeSize) {
+ case 4:
+ assert(numElements == 16);
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
+ srcPtr);
+ break;
+ case 6:
+ // To use ds_read_tr6_b96, the load size is vector<3xi32>.
+ // TODO: support native 6-bit data types.
+ assert(numElements == 16);
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
+ srcPtr);
+ break;
+ case 8:
+ assert(numElements == 8);
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
+ srcPtr);
+ break;
+ case 16:
+ assert(numElements == 4);
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
----------------
lialan wrote:
Found that, updated.
https://github.com/llvm/llvm-project/pull/145395
More information about the Mlir-commits mailing list