[Mlir-commits] [mlir] [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (PR #145395)

Alan Li llvmlistbot at llvm.org
Tue Jun 24 10:26:37 PDT 2025


================
@@ -1100,6 +1100,60 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct TransposeLoadOpLowering
+    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset != kGfx950)
+      return op.emitOpError("Non-gfx950 chipset not supported");
+
+    Location loc = op.getLoc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto resultType = cast<VectorType>(op.getResult().getType());
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+
+    size_t numElements = resultType.getNumElements();
+    size_t elementTypeSize =
+        resultType.getElementType().getIntOrFloatBitWidth();
+
+    switch (elementTypeSize) {
+    case 4:
+      assert(numElements == 16);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
+                                                          srcPtr);
+      break;
+    case 6:
+      // To use ds_read_tr6_b96, the load size is vector<3xi32>.
+      // TODO: support native 6-bit data types.
+      assert(numElements == 16);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
+                                                          srcPtr);
+      break;
+    case 8:
+      assert(numElements == 8);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
+                                                          srcPtr);
+      break;
+    case 16:
+      assert(numElements == 4);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
----------------
lialan wrote:

Found it; updated.

https://github.com/llvm/llvm-project/pull/145395


More information about the Mlir-commits mailing list