[Mlir-commits] [mlir] [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (PR #145395)

Alan Li llvmlistbot at llvm.org
Tue Jun 24 10:22:42 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/145395

>From 2009ede2140d8f260c77b50ef0e9c3085f79c2e3 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 15:25:06 -0400
Subject: [PATCH 01/12] [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL
 transpose loads.

* 1-to-1 mapping wrapper op.
* Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 21 +++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 47 ++++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 18 +++++++
 3 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..003aff6d38da0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_TransposeLoadOp :
+    AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+    Results<(outs MFMAInTypes:$dst)> {
+  let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+  let description = [{
+    The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+
+    Operands:
+    * `$src`: LDS memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: target register this transpose load instruction will write to.
+
+    Note: Lowering is only supported on gfx950 and up.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..62ed1d871bcfd 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct TransposeLoadOpLowering
+    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx950)
+      return op.emitOpError("Non-gfx950 chipset not supported");
+
+    Location loc = op.getLoc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+    auto elementTypeSize = cast<VectorType>(op.getDst().getType())
+                               .getElementType()
+                               .getIntOrFloatBitWidth();
+
+    // TODO: support ds_read_tr16_b64 intrinsic.
+    switch (elementTypeSize) {
+    case 4:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    case 8:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    case 16:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    default:
+      return op.emitOpError("Unsupported element size for transpose load");
+    }
+    return success();
+  }
+};
+
 struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
            ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
-                                                                 chipset);
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..9d0d76d0ad843 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+LogicalResult TransposeLoadOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Workgroup");
+
+  // TODO: support 6-bit element type vectors.
+  auto transferType = dyn_cast<VectorType>(getDst().getType());
+  if (!transferType)
+    return emitOpError("destination type must be a vector type");
+  size_t transferSize =
+      transferType.getNumElements() * transferType.getElementTypeBitWidth();
+  if (transferSize != 64)
+    return emitOpError("Transferring type size must be 64 bits");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

>From 50d19a627afdc37182bf5faa5c2ab562cff59674 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 15:42:01 -0400
Subject: [PATCH 02/12] Adding a test file

---
 .../AMDGPUToROCDL/transpose_load.mlir          | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
new file mode 100644
index 0000000000000..3e20d51efc93f
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
+func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, #gpu_lds_addrspace>) -> vector<4xf16> {
+  // CHECK: rocdl.ds.read.tr16.b64
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, #gpu_lds_addrspace> -> vector<4xf16>
+  return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
+func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, #gpu_lds_addrspace>) -> vector<8xi8> {
+  // CHECK: rocdl.ds.read.tr8.b64
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, #gpu_lds_addrspace> -> vector<8xi8>
+  return %0 : vector<8xi8>
+}

>From 087046a0373355f145124c0bcbde3cace93e03fe Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 15:51:19 -0400
Subject: [PATCH 03/12] Adding 6-bit loads.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  6 ++---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  4 ++++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 24 +++++++++++++++----
 3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 003aff6d38da0..f3cd5e5eeb2da 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -901,7 +901,7 @@ def AMDGPU_GatherToLDSOp :
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
-    Results<(outs MFMAInTypes:$dst)> {
+    Results<(outs MFMAInTypes:$result)> {
   let summary = "MLIR wrapper for CDNA Transpose Load instructions";
   let description = [{
     The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
@@ -909,12 +909,12 @@ def AMDGPU_TransposeLoadOp :
     Operands:
     * `$src`: LDS memref to read from.
     * `$srcIndices`: indices into `$src` to read from for this thread.
-    * `$dst`: target register this transpose load instruction will write to.
+    * `$result`: target register this transpose load instruction will write to.
 
     Note: Lowering is only supported on gfx950 and up.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
+    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 62ed1d871bcfd..3d8ba045eb1e3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1128,6 +1128,10 @@ struct TransposeLoadOpLowering
       rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
           op, op.getDst().getType(), srcPtr);
       break;
+    case 6:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
     case 8:
       rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
           op, op.getDst().getType(), srcPtr);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 9d0d76d0ad843..a72b4031da644 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -536,10 +536,26 @@ LogicalResult TransposeLoadOp::verify() {
     return emitOpError("destination type must be a vector type");
   size_t transferSize =
       transferType.getNumElements() * transferType.getElementTypeBitWidth();
-  if (transferSize != 64)
-    return emitOpError("Transferring type size must be 64 bits");
-
-  return success();
+  size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
+
+  // ElementSize -> LoadSize
+  const std::map<int, int> KValidLoadSizeMap = {
+      {4, 64},
+      {6, 96},
+      {8, 64},
+      {16, 64},
+  };
+
+  auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
+  if (validLoadSize == KValidLoadSizeMap.end())
+    return emitOpError("Unsupported element type size for transpose load: ")
+           << elementTypeSize << " bits";
+  if (transferSize != validLoadSize->second)
+    return emitOpError("Transferring type size must be ")
+           << validLoadSize->second
+           << " bits for element type size "
+
+           return success();
 }
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

>From fa30258c78e1de2dbcd68f4315541d5727c55028 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 16:29:07 -0400
Subject: [PATCH 04/12] Adding support for 6-bit loadings.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 +++++++++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 33 +++++++++++--------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 17 +++++-----
 .../AMDGPUToROCDL/transpose_load.mlir         | 14 ++++++++
 4 files changed, 60 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f3cd5e5eeb2da..3a08bb6cfcce4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,10 +898,26 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def F8Types : AnyTypeOf<[
+  F8E8M0FNU,      // 8 exponent, 0 mantissa
+  F8E5M2,         // 5 exponent, 2 mantissa
+  F8E5M2FNUZ,     // 5 exponent, 2 mantissa
+  F8E4M3,         // 4 exponent, 3 mantissa
+  F8E4M3FN,       // 4 exponent, 3 mantissa
+  F8E4M3B11FNUZ,  // 4 exponent, 3 mantissa (with bias 11)
+  F8E3M4          // 3 exponent, 4 mantissa
+]>;
+def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
+def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
+                             VectorOfLengthAndType<[8], [F8Types, AnyI<8>]>,
+                             VectorOfLengthAndType<[16], [AnyI<4>, F6Types]>,
+                             VectorOfLengthAndType<[3], [I32]>,
+                           ]>;
+
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
-    Results<(outs MFMAInTypes:$result)> {
+    Results<(outs TrLoadTypes:$result)> {
   let summary = "MLIR wrapper for CDNA Transpose Load instructions";
   let description = [{
     The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3d8ba045eb1e3..61da9639d9539 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1115,30 +1115,37 @@ struct TransposeLoadOpLowering
 
     Location loc = op.getLoc();
     auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto resultType = cast<VectorType>(op.getResult().getType());
     Value srcPtr =
         getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                              (adaptor.getSrcIndices()));
-    auto elementTypeSize = cast<VectorType>(op.getDst().getType())
-                               .getElementType()
-                               .getIntOrFloatBitWidth();
 
-    // TODO: support ds_read_tr16_b64 intrinsic.
+    size_t numElements = resultType.getNumElements();
+    size_t elementTypeSize =
+        resultType.getElementType().getIntOrFloatBitWidth();
+
     switch (elementTypeSize) {
     case 4:
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
-          op, op.getDst().getType(), srcPtr);
+      assert(numElements == 16);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
+                                                          srcPtr);
       break;
-    case 6:
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b64>(
-          op, op.getDst().getType(), srcPtr);
+    case 32:
+      // To use ds_read_tr6_b96, the load size is vector<3xi32>.
+      // TODO: support native 6-bit data types.
+      assert(numElements == 3);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
+                                                          srcPtr);
       break;
     case 8:
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
-          op, op.getDst().getType(), srcPtr);
+      assert(numElements == 8);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
+                                                          srcPtr);
       break;
     case 16:
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
-          op, op.getDst().getType(), srcPtr);
+      assert(numElements == 4);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
+                                                           srcPtr);
       break;
     default:
       return op.emitOpError("Unsupported element size for transpose load");
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index a72b4031da644..6074f3652fb56 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -531,7 +531,7 @@ LogicalResult TransposeLoadOp::verify() {
     return emitOpError("source memory address space must be Workgroup");
 
   // TODO: support 6-bit element type vectors.
-  auto transferType = dyn_cast<VectorType>(getDst().getType());
+  auto transferType = dyn_cast<VectorType>(getType());
   if (!transferType)
     return emitOpError("destination type must be a vector type");
   size_t transferSize =
@@ -539,23 +539,24 @@ LogicalResult TransposeLoadOp::verify() {
   size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
 
   // ElementSize -> LoadSize
-  const std::map<int, int> KValidLoadSizeMap = {
+  const std::map<size_t, size_t> KValidLoadSizeMap = {
       {4, 64},
-      {6, 96},
+      {32, 96}, // 6-bit element loads use casted vector<3xi32>
       {8, 64},
       {16, 64},
   };
 
   auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
-  if (validLoadSize == KValidLoadSizeMap.end())
+  if (validLoadSize == KValidLoadSizeMap.end()) {
     return emitOpError("Unsupported element type size for transpose load: ")
            << elementTypeSize << " bits";
-  if (transferSize != validLoadSize->second)
+  }
+  if (transferSize != validLoadSize->second) {
     return emitOpError("Transferring type size must be ")
-           << validLoadSize->second
-           << " bits for element type size "
+           << validLoadSize->second << " bits for element type size ";
+  }
 
-           return success();
+  return success();
 }
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index 3e20d51efc93f..d2a75e445e239 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -16,3 +16,17 @@ func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : m
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, #gpu_lds_addrspace> -> vector<8xi8>
   return %0 : vector<8xi8>
 }
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_16xi4
+func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, #gpu_lds_addrspace>) -> vector<16xi4> {
+  // CHECK: rocdl.ds.read.tr4.b64
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, #gpu_lds_addrspace> -> vector<16xi4>
+  return %0 : vector<16xi4>
+}
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
+func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, #gpu_lds_addrspace>) -> vector<3xi32> {
+  // CHECK: rocdl.ds.read.tr6.b96
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, #gpu_lds_addrspace> -> vector<3xi32>
+  return %0 : vector<3xi32>
+}

>From 4259f63032bc734c9765ff7af69c3bea664d948b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 18:54:02 -0400
Subject: [PATCH 05/12] Adding check nots.

---
 .../AMDGPUToROCDL/transpose_load.mlir         | 29 ++++++++++++-------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index d2a75e445e239..a9111d6536286 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -1,32 +1,39 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
-
-#gpu_lds_addrspace = 3
-#amdgpu_fat_buffer_addrspace = 7
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD 
 
 // CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
-func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, #gpu_lds_addrspace>) -> vector<4xf16> {
+func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
   // CHECK: rocdl.ds.read.tr16.b64
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, #gpu_lds_addrspace> -> vector<4xf16>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
   return %0 : vector<4xf16>
 }
 
+// -----
+
 // CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
-func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, #gpu_lds_addrspace>) -> vector<8xi8> {
+func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
   // CHECK: rocdl.ds.read.tr8.b64
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, #gpu_lds_addrspace> -> vector<8xi8>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
   return %0 : vector<8xi8>
 }
 
+// -----
+
 // CHECK-LABEL: func @transpose_load_to_rocdl_16xi4
-func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, #gpu_lds_addrspace>) -> vector<16xi4> {
+func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
   // CHECK: rocdl.ds.read.tr4.b64
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, #gpu_lds_addrspace> -> vector<16xi4>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
   return %0 : vector<16xi4>
 }
 
+// -----
+
 // CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
-func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, #gpu_lds_addrspace>) -> vector<3xi32> {
+func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, 3>) -> vector<3xi32> {
   // CHECK: rocdl.ds.read.tr6.b96
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, #gpu_lds_addrspace> -> vector<3xi32>
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, 3> -> vector<3xi32>
   return %0 : vector<3xi32>
 }

>From c8157f00e485d32a9c92f871c70a7f5817404c40 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 19:12:04 -0400
Subject: [PATCH 06/12] Update the doc in the code.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3a08bb6cfcce4..416dd79aabebc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,7 +921,20 @@ def AMDGPU_TransposeLoadOp :
   let summary = "MLIR wrapper for CDNA Transpose Load instructions";
   let description = [{
     The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
-
+    The transpose load op represents a subgroup load from LDS memory,
+    where the subgroup of threads collectively reads a matrix from the source
+    memref, with each thread reading a vector of the matrix, and gets a transposed matrix
+    as the result. That is, each thread reads a vector of the col-major matrix at different
+    indices, and the thread's read result is a vector of the corresponding row of the transposed
+    matrix.
+
+    This op is a direct wrapper around the ROCDL `ds_read_tr` family of intrinsics. Please refer
+    to the ROCDL documentation for more details about its exact semantics.
+
+    Format example:
+    ```
+    %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16, 3> -> vector<4xf16>
+    ```
     Operands:
     * `$src`: LDS memref to read from.
     * `$srcIndices`: indices into `$src` to read from for this thread.

>From bbb57eaa1f4c3bd670f6b115ef0041bf1a1a26b5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 21:17:49 -0400
Subject: [PATCH 07/12] Update

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  8 +-----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 25 +++++++++----------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 416dd79aabebc..8d150bec789ec 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -907,17 +907,11 @@ def F8Types : AnyTypeOf<[
   F8E4M3B11FNUZ,  // 4 exponent, 3 mantissa (with bias 11)
   F8E3M4          // 3 exponent, 4 mantissa
 ]>;
-def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
-def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
-                             VectorOfLengthAndType<[8], [F8Types, AnyI<8>]>,
-                             VectorOfLengthAndType<[16], [AnyI<4>, F6Types]>,
-                             VectorOfLengthAndType<[3], [I32]>,
-                           ]>;
 
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
-    Results<(outs TrLoadTypes:$result)> {
+    Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
   let summary = "MLIR wrapper for CDNA Transpose Load instructions";
   let description = [{
     The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 6074f3652fb56..77fe645c38a17 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -530,30 +530,29 @@ LogicalResult TransposeLoadOp::verify() {
   if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
     return emitOpError("source memory address space must be Workgroup");
 
-  // TODO: support 6-bit element type vectors.
   auto transferType = dyn_cast<VectorType>(getType());
   if (!transferType)
     return emitOpError("destination type must be a vector type");
-  size_t transferSize =
-      transferType.getNumElements() * transferType.getElementTypeBitWidth();
+  size_t numElements = transferType.getNumElements();
   size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
 
-  // ElementSize -> LoadSize
+  // ElementSize -> NumElements
   const std::map<size_t, size_t> KValidLoadSizeMap = {
-      {4, 64},
-      {32, 96}, // 6-bit element loads use casted vector<3xi32>
-      {8, 64},
-      {16, 64},
+      {4, 16},
+      {32, 3}, // 6-bit element loads use casted vector<3xi32>
+      {8, 8},
+      {16, 4},
   };
 
-  auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
-  if (validLoadSize == KValidLoadSizeMap.end()) {
+  auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
+  if (validNumElems == KValidLoadSizeMap.end()) {
     return emitOpError("Unsupported element type size for transpose load: ")
            << elementTypeSize << " bits";
   }
-  if (transferSize != validLoadSize->second) {
-    return emitOpError("Transferring type size must be ")
-           << validLoadSize->second << " bits for element type size ";
+  if (numElements != validNumElems->second) {
+    return emitOpError(
+               "Transferring type size mismatch: expected num of elements: ")
+           << validNumElems->second;
   }
 
   return success();

>From 60e2c56a68828d4f075dd084baae8ca26ff46fe8 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 21:40:06 -0400
Subject: [PATCH 08/12] Adding loads from different value type.

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  2 +-
 .../AMDGPUToROCDL/transpose_load.mlir         | 21 +++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 77fe645c38a17..c22c155b2a60f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -534,7 +534,7 @@ LogicalResult TransposeLoadOp::verify() {
   if (!transferType)
     return emitOpError("destination type must be a vector type");
   size_t numElements = transferType.getNumElements();
-  size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
+  size_t elementTypeSize = transferType.getElementType().getIntOrFloatBitWidth();
 
   // ElementSize -> NumElements
   const std::map<size_t, size_t> KValidLoadSizeMap = {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index a9111d6536286..fd1a278c83991 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -34,6 +34,27 @@ func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem :
 // CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
 func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, 3>) -> vector<3xi32> {
   // CHECK: rocdl.ds.read.tr6.b96
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, 3> -> vector<3xi32>
   return %0 : vector<3xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi1
+func.func @transpose_load_to_rocdl_i4_memrefxi1(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
+  // CHECK: rocdl.ds.read.tr4.b64
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
+  return %0 : vector<16xi4>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi1
+func.func @transpose_load_to_rocdl_i6_memrefxi1(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<3xi32> {
+  // CHECK: rocdl.ds.read.tr6.b96
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<3xi32>
+  return %0 : vector<3xi32>
+}
\ No newline at end of file

>From 9bba79fe0bcc29214b7f2a2b30fe79078a3cca09 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 23 Jun 2025 22:01:13 -0400
Subject: [PATCH 09/12] Update

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td       | 10 ----------
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        |  6 ++++--
 3 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 8d150bec789ec..c88d37ca7591b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,16 +898,6 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
-def F8Types : AnyTypeOf<[
-  F8E8M0FNU,      // 8 exponent, 0 mantissa
-  F8E5M2,         // 5 exponent, 2 mantissa
-  F8E5M2FNUZ,     // 5 exponent, 2 mantissa
-  F8E4M3,         // 4 exponent, 3 mantissa
-  F8E4M3FN,       // 4 exponent, 3 mantissa
-  F8E4M3B11FNUZ,  // 4 exponent, 3 mantissa (with bias 11)
-  F8E3M4          // 3 exponent, 4 mantissa
-]>;
-
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 61da9639d9539..11f3d7e4c16f7 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1110,7 +1110,7 @@ struct TransposeLoadOpLowering
   LogicalResult
   matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (chipset < kGfx950)
+    if (chipset != kGfx950)
       return op.emitOpError("Non-gfx950 chipset not supported");
 
     Location loc = op.getLoc();
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index c22c155b2a60f..747e3622ef3e0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include <limits>
@@ -534,10 +535,11 @@ LogicalResult TransposeLoadOp::verify() {
   if (!transferType)
     return emitOpError("destination type must be a vector type");
   size_t numElements = transferType.getNumElements();
-  size_t elementTypeSize = transferType.getElementType().getIntOrFloatBitWidth();
+  size_t elementTypeSize =
+      transferType.getElementType().getIntOrFloatBitWidth();
 
   // ElementSize -> NumElements
-  const std::map<size_t, size_t> KValidLoadSizeMap = {
+  const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
       {4, 16},
       {32, 3}, // 6-bit element loads use casted vector<3xi32>
       {8, 8},

>From b5b4e6fdd39adaa04d02a1687404c0cd18ebd17c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 24 Jun 2025 11:55:51 -0400
Subject: [PATCH 10/12] Use vector<16xi6> instead of vector<3xi32> for 6-bit loads.

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  4 +--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  2 +-
 .../AMDGPUToROCDL/transpose_load.mlir         | 32 ++++++++++++-------
 3 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 11f3d7e4c16f7..71e486c76e1df 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1130,10 +1130,10 @@ struct TransposeLoadOpLowering
       rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
                                                           srcPtr);
       break;
-    case 32:
+    case 6:
       // To use ds_read_tr6_b96, the load size is vector<3xi32>.
       // TODO: support native 6-bit data types.
-      assert(numElements == 3);
+      assert(numElements == 16);
       rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
                                                           srcPtr);
       break;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 747e3622ef3e0..715949e16bf36 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -541,7 +541,7 @@ LogicalResult TransposeLoadOp::verify() {
   // ElementSize -> NumElements
   const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
       {4, 16},
-      {32, 3}, // 6-bit element loads use casted vector<3xi32>
+      {6, 16},
       {8, 8},
       {16, 4},
   };
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index fd1a278c83991..fae44e5674d7c 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -31,18 +31,18 @@ func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem :
 
 // -----
 
-// CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
-func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, 3>) -> vector<3xi32> {
+// CHECK-LABEL: func @transpose_load_to_rocdl_16xi6
+func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
   // CHECK: rocdl.ds.read.tr6.b96
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, 3> -> vector<3xi32>
-  return %0 : vector<3xi32>
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
+  return %0 : vector<16xi6>
 }
 
 // -----
 
-// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi1
-func.func @transpose_load_to_rocdl_i4_memrefxi1(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
+// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
+func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
   // CHECK: rocdl.ds.read.tr4.b64
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
@@ -51,10 +51,20 @@ func.func @transpose_load_to_rocdl_i4_memrefxi1(%idx1 : index, %idx2 : index, %w
 
 // -----
 
-// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi1
-func.func @transpose_load_to_rocdl_i6_memrefxi1(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<3xi32> {
+// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
+func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
   // CHECK: rocdl.ds.read.tr6.b96
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<3xi32>
-  return %0 : vector<3xi32>
-}
\ No newline at end of file
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
+  return %0 : vector<16xi6>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
+func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
+  // CHECK: rocdl.ds.read.tr16.b64
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
+  return %0 : vector<4xi16>
+}

>From 207f2f4b8f89c317abd1d258b90782888f53a6c3 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 24 Jun 2025 12:20:14 -0400
Subject: [PATCH 11/12] Reject sub-byte source memrefs.

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 10 ++++++++++
 .../AMDGPUToROCDL/transpose_load.mlir         | 20 -------------------
 .../AMDGPUToROCDL/transpose_load_reject.mlir  | 17 ++++++++++++++++
 3 files changed, 27 insertions(+), 20 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 71e486c76e1df..2424d2d3ad343 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1115,6 +1115,16 @@ struct TransposeLoadOpLowering
 
     Location loc = op.getLoc();
     auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+
+    // Elements in subbyte memrefs are stored non-contiguously,
+    // reject if source is sub-byte memref. Use emulated memrefs instead.
+    size_t srcElementSize =
+        srcMemRefType.getElementType().getIntOrFloatBitWidth();
+    if (srcElementSize < 8)
+      return op.emitOpError("Expect source memref to have at least 8 bits "
+                            "element size, got ")
+             << srcElementSize;
+
     auto resultType = cast<VectorType>(op.getResult().getType());
     Value srcPtr =
         getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index fae44e5674d7c..cccd6f0a51345 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -21,26 +21,6 @@ func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : m
 
 // -----
 
-// CHECK-LABEL: func @transpose_load_to_rocdl_16xi4
-func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
-  // CHECK: rocdl.ds.read.tr4.b64
-  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
-  return %0 : vector<16xi4>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_load_to_rocdl_16xi6
-func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
-  // CHECK: rocdl.ds.read.tr6.b96
-  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
-  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
-  return %0 : vector<16xi6>
-}
-
-// -----
-
 // CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
 func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
   // CHECK: rocdl.ds.read.tr4.b64
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
new file mode 100644
index 0000000000000..a41051c904ed8
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
@@ -0,0 +1,17 @@
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
+  // CHECK: memref to have at least 8 bits element size, got 4
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
+  return %0 : vector<16xi4>
+}
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
+  // CHECK: memref to have at least 8 bits element size, got 6
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
+  return %0 : vector<16xi6>
+}

>From db9b8372f3b382fdac36676dc23297383ffb21ac Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 24 Jun 2025 13:21:59 -0400
Subject: [PATCH 12/12] Use i32 vectors as placeholder results for sub-byte and byte loads.

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 37 ++++++++++++-------
 .../AMDGPUToROCDL/transpose_load.mlir         | 14 +++++--
 2 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2424d2d3ad343..fe1b0c1095180 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1134,29 +1135,39 @@ struct TransposeLoadOpLowering
     size_t elementTypeSize =
         resultType.getElementType().getIntOrFloatBitWidth();
 
+    // ROCDL transpose load intrinsics return vectors of 32-bit integers.
+    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
+                                           rewriter.getIntegerType(32));
+    Type llvmResultType = typeConverter->convertType(resultType);
+
     switch (elementTypeSize) {
-    case 4:
+    case 4: {
       assert(numElements == 16);
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
-                                                          srcPtr);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
-    case 6:
-      // To use ds_read_tr6_b96, the load size is vector<3xi32>.
-      // TODO: support native 6-bit data types.
+    }
+    case 6: {
       assert(numElements == 16);
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
-                                                          srcPtr);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
-    case 8:
+    }
+    case 8: {
       assert(numElements == 8);
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
-                                                          srcPtr);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
-    case 16:
+    }
+    case 16: {
       assert(numElements == 4);
-      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                            srcPtr);
       break;
+    }
     default:
       return op.emitOpError("Unsupported element size for transpose load");
     }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
index cccd6f0a51345..68799098f1d36 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
 // RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD 
 
 // CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
@@ -13,7 +13,9 @@ func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem :
 
 // CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
 func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
-  // CHECK: rocdl.ds.read.tr8.b64
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
+  // CHECK-SAME: -> vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
   return %0 : vector<8xi8>
@@ -23,7 +25,9 @@ func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : m
 
 // CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
 func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
-  // CHECK: rocdl.ds.read.tr4.b64
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
+  // CHECK-SAME: -> vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
   return %0 : vector<16xi4>
@@ -33,7 +37,9 @@ func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %w
 
 // CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
 func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
-  // CHECK: rocdl.ds.read.tr6.b96
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
+  // CHECK-SAME: -> vector<3xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
   // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
   %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
   return %0 : vector<16xi6>



More information about the Mlir-commits mailing list