[Mlir-commits] [mlir] 3f3282c - [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 19:58:17 PDT 2025


Author: Alan Li
Date: 2025-06-25T22:58:14-04:00
New Revision: 3f3282cee87f307afe58c899f03df3a882846290

URL: https://github.com/llvm/llvm-project/commit/3f3282cee87f307afe58c899f03df3a882846290
DIFF: https://github.com/llvm/llvm-project/commit/3f3282cee87f307afe58c899f03df3a882846290.diff

LOG: [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)

* 1-to-1 mapping wrapper op.
* Direct lowering from AMDGPU wrapper to ROCDL intrinsics.

Added: 
    mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
    mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Dialect/AMDGPU/invalid.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..eadb5d9326798 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_TransposeLoadOp :
+    AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+    Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
+  let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+  let description = [{
+    The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+    The transpose load op represents a subgroup load from LDS memory,
+    where the subgroup of threads collectively reads a matrix from the source
+    memref, with each thread reading a vector of the matrix, and gets a transposed matrix
+    in as the result. That is, each thread reads a vector of the col-major matrix at different
+    indices, and the thread's read result is a vector of the corresponding row of the transposed
+    matrix.
+
+    This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer
+    to the CDNA4 ISA documentation for more details about its exact semantics.
+
+    Format example:
+    ```
+    %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
+    ```
+    Operands:
+    * `$src`: LDS memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$result`: target register this transpose load instruction will write to.
+
+    Note: Lowering is only supported on gfx950 and up.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..910fe1b1d93c1 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct TransposeLoadOpLowering
+    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset != kGfx950)
+      return op.emitOpError("Non-gfx950 chipset not supported");
+
+    Location loc = op.getLoc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+
+    // Elements in subbyte memrefs are stored non-contiguously,
+    // reject if source is sub-byte memref. Use emulated memrefs instead.
+    size_t srcElementSize =
+        srcMemRefType.getElementType().getIntOrFloatBitWidth();
+    if (srcElementSize < 8)
+      return op.emitOpError("Expect source memref to have at least 8 bits "
+                            "element size, got ")
+             << srcElementSize;
+
+    auto resultType = cast<VectorType>(op.getResult().getType());
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+
+    size_t numElements = resultType.getNumElements();
+    size_t elementTypeSize =
+        resultType.getElementType().getIntOrFloatBitWidth();
+
+    // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
+    // the element size is smaller than 16 bits.
+    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
+                                           rewriter.getIntegerType(32));
+    Type llvmResultType = typeConverter->convertType(resultType);
+
+    switch (elementTypeSize) {
+    case 4: {
+      assert(numElements == 16);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+      break;
+    }
+    case 6: {
+      assert(numElements == 16);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+      break;
+    }
+    case 8: {
+      assert(numElements == 8);
+      auto rocdlOp =
+          rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+      break;
+    }
+    case 16: {
+      assert(numElements == 4);
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
+                                                           srcPtr);
+      break;
+    }
+    default:
+      return op.emitOpError("Unsupported element size for transpose load");
+    }
+    return success();
+  }
+};
+
 struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
            ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
-                                                                 chipset);
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..4613d14461969 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include <limits>
@@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+LogicalResult TransposeLoadOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Workgroup");
+
+  auto transferType = cast<VectorType>(getType());
+  size_t numElements = transferType.getNumElements();
+  size_t elementTypeSize =
+      transferType.getElementType().getIntOrFloatBitWidth();
+
+  // ElementSize -> NumElements
+  const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
+      {4, 16},
+      {6, 16},
+      {8, 8},
+      {16, 4},
+  };
+
+  auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
+  if (validNumElems == KValidLoadSizeMap.end()) {
+    return emitOpError("Unsupported element type size for transpose load: ")
+           << elementTypeSize << " bits";
+  }
+  if (numElements != validNumElems->second) {
+    return emitOpError(
+               "Transferring type size mismatch: expected num of elements: ")
+           << validNumElems->second;
+  }
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
new file mode 100644
index 0000000000000..68799098f1d36
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD 
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
+func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
+  // CHECK: rocdl.ds.read.tr16.b64
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
+  return %0 : vector<4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
+func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
+  // CHECK-SAME: -> vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
+func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
+  // CHECK-SAME: -> vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
+  return %0 : vector<16xi4>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
+func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
+  // CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
+  // CHECK-SAME: -> vector<3xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
+  return %0 : vector<16xi6>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
+func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
+  // CHECK: rocdl.ds.read.tr16.b64
+  // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
+  return %0 : vector<4xi16>
+}

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
new file mode 100644
index 0000000000000..a41051c904ed8
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
@@ -0,0 +1,17 @@
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
+  // CHECK: memref to have at least 8 bits element size, got 4
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
+  return %0 : vector<16xi4>
+}
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
+  // CHECK: memref to have at least 8 bits element size, got 6
+  %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
+  return %0 : vector<16xi6>
+}

diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 73306ba6b3f93..6d55583f8bc7c 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
   %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
   func.return %0 : vector<[4]xf32>
 }
+
+// -----
+
+func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
+  func.return %0 : vector<4xf16>
+}
+
+// -----
+
+func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
+  func.return %0 : vector<4xf16>
+}
+
+// -----
+
+func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32>
+  func.return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16>
+  func.return %0 : vector<2xf16>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4>
+  func.return %0 : vector<20xi4>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8>
+  func.return %0 : vector<20xi8>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
+  // expected-error @+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
+  func.return %0 : vector<8xi6>
+}

diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 6c3ffb575f7c2..51f3bbd9ae45c 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
   %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
   func.return %0 : vector<16xf32>
 }
+
+// CHECK-LABEL: func @transpose_load
+func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> {
+  // CHECK: amdgpu.transpose_load
+  %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
+  func.return %0 : vector<4xf16>
+}


        


More information about the Mlir-commits mailing list