[Mlir-commits] [mlir] [MLIR][AMDGPU] Add amdgpu.global_transpose_load op for RDNA4 global memory transpose loads (PR #195287)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 1 09:31:33 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

<details>
<summary>Changes</summary>

Adds a new `amdgpu.global_transpose_load` op to the AMDGPU dialect that wraps the `global_load_tr` family of instructions introduced in RDNA4 (gfx1250+). Each thread reads a column of a matrix from global memory and receives the corresponding transposed row in its result register.

The op is kept separate from the existing `amdgpu.transpose_load` (which targets LDS via `ds_read_tr` on gfx950+) because the two variants target different GPU architecture families, have different chipset requirements, and differ in their valid (element size, num elements) combinations — in particular the 16-bit case produces a 128-bit (8-element) result via `global_load_tr.b128` rather than the 64-bit (4-element) result from `ds_read_tr16.b64`.

The op lowers to the existing ROCDL `global.load.tr.b{64,128}` intrinsics, which were added for gfx1250+.

---
Full diff: https://github.com/llvm/llvm-project/pull/195287.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td (+39) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+61) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp (+35) 
- (added) mlir/test/Conversion/AMDGPUToROCDL/global_transpose_load.mlir (+56) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 4112ea281bb96..edc68d8c0e590 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1449,6 +1449,52 @@ def AMDGPU_TransposeLoadOp :
   let hasVerifier = 1;
 }
 
+// Global-memory counterpart of AMDGPU_TransposeLoadOp (which reads LDS via
+// ds_read_tr). It is a separate op because the chipset requirements and the
+// valid (element bits, num elements) pairs differ between the two.
+def AMDGPU_GlobalTransposeLoadOp :
+    AMDGPU_Op<"global_transpose_load", [SameVariadicOperandSize]>,
+    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src,
+                      Variadic<Index>:$srcIndices)>,
+    Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
+  let summary = "MLIR wrapper for RDNA4 global memory transpose load instructions";
+  let description = [{
+    The `amdgpu.global_transpose_load` op is a wrapper around the
+    `global_load_tr` family of instructions introduced in RDNA4 (gfx1250+).
+
+    Each thread reads a column of a matrix stored in global memory and receives
+    the corresponding row of the transposed matrix in its result register.
+    The subgroup collectively performs a transpose of the tile.
+
+    This op is a direct wrapper around the ROCDL `global.load.tr` family of
+    intrinsics. Refer to the RDNA4 ISA documentation for exact semantics.
+
+    Format example:
+    ```
+    %0 = amdgpu.global_transpose_load %src[%i, %j]
+           : memref<128x256xf16, #gpu.address_space<global>> -> vector<8xf16>
+    ```
+
+    Operands:
+    * `$src`: Global address space memref to read from.
+    * `$srcIndices`: indices into `$src` for this thread, one per dimension of
+      `$src`.
+
+    Result:
+    * `$result`: register this transpose load instruction writes to.
+
+    Valid (element bits, num elements) pairs:
+    * (8, 8)   -> global_load_tr_b64
+    * (16, 8)  -> global_load_tr_b128
+
+    Note: Lowering is only supported on gfx1250 and up.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8464d1e29f0aa..5844a845bd9e6 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2226,6 +2226,66 @@ struct TransposeLoadOpLowering
   }
 };
 
+/// Lowers `amdgpu.global_transpose_load` to the ROCDL
+/// `global.load.tr.b{64,128}` ops. The asserts below rely on the op verifier
+/// having restricted the result to 8 elements of 8 or 16 bits.
+struct GlobalTransposeLoadOpLowering
+    : public ConvertOpToLLVMPattern<GlobalTransposeLoadOp> {
+  GlobalTransposeLoadOpLowering(const LLVMTypeConverter &converter,
+                                Chipset chipset)
+      : ConvertOpToLLVMPattern<GlobalTransposeLoadOp>(converter),
+        chipset(chipset) {}
+
+  /// Chipset being targeted; compared against kGfx1250 before rewriting.
+  Chipset chipset;
+
+  /// Fails with an op error on pre-gfx1250 chipsets or unsupported sizes.
+  LogicalResult
+  matchAndRewrite(GlobalTransposeLoadOp op,
+                  GlobalTransposeLoadOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op.emitOpError(
+          "global_transpose_load is only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+    auto memrefTy = cast<MemRefType>(op.getSrc().getType());
+    auto vecTy = cast<VectorType>(op.getResult().getType());
+
+    // Address of the element this thread starts reading from.
+    Value addr =
+        getStridedElementPtr(rewriter, loc, memrefTy, adaptor.getSrc(),
+                             adaptor.getSrcIndices());
+    Type convertedTy = typeConverter->convertType(vecTy);
+
+    const size_t elemBits = vecTy.getElementType().getIntOrFloatBitWidth();
+    const size_t elemCount = vecTy.getNumElements();
+
+    // 8 x 16-bit: the b128 op already produces the requested vector type.
+    if (elemBits == 16) {
+      assert(elemCount == 8);
+      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadTr8_B128>(op, convertedTy,
+                                                             addr);
+      return success();
+    }
+
+    // 8 x 8-bit: sub-16-bit results come back packed into i32 lanes (same
+    // convention as the LDS transpose-load lowering), so bitcast them back.
+    if (elemBits == 8) {
+      assert(elemCount == 8);
+      auto packedTy = VectorType::get((elemCount * elemBits) / 32,
+                                      rewriter.getIntegerType(32));
+      auto packed =
+          ROCDL::GlobalLoadTr8_B64::create(rewriter, loc, packedTy, addr);
+      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, convertedTy, packed);
+      return success();
+    }
+
+    return op.emitOpError(
+        "unsupported element size for global transpose load");
+  }
+};
+
 struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -4408,6 +4468,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
            PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
            GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
+           GlobalTransposeLoadOpLowering,
            AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
            AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
            AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 2f6f59194fba3..7d9bccd899a69 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -1079,6 +1079,57 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalTransposeLoadOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `src` lives in global memory, that exactly one index is
+/// supplied per dimension of `src`, and that the result vector's
+/// (element bit width, length) pair maps to a `global_load_tr` variant.
+LogicalResult GlobalTransposeLoadOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
+
+  // The lowering (getStridedElementPtr) pairs each index with a memref
+  // stride; too many or too few indices would yield an invalid address.
+  if (static_cast<int64_t>(getSrcIndices().size()) != srcType.getRank())
+    return emitOpError("expected ")
+           << srcType.getRank() << " indices into the source memref, got "
+           << getSrcIndices().size();
+
+  auto resultType = cast<VectorType>(getType());
+  Type elementType = resultType.getElementType();
+  // getIntOrFloatBitWidth() asserts on non-int/float types (e.g. `index`),
+  // so reject those with a diagnostic instead of crashing.
+  if (!elementType.isIntOrFloat())
+    return emitOpError(
+        "result element type must be an integer or float type");
+
+  size_t numElements = resultType.getNumElements();
+  size_t elementTypeSize = elementType.getIntOrFloatBitWidth();
+
+  // ElementSize -> NumElements (matches ISA-documented global_load_tr variants)
+  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
+      {8, 8},   // global_load_tr_b64
+      {16, 8},  // global_load_tr_b128
+  };
+
+  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
+  if (validNumElems == kValidLoadSizeMap.end())
+    return emitOpError(
+               "unsupported element type size for global transpose load: ")
+           << elementTypeSize << " bits";
+
+  if (numElements != validNumElems->second)
+    return emitOpError(
+               "transferring type size mismatch: expected num of elements: ")
+           << validNumElems->second;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MakeDmaBaseOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/global_transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/global_transpose_load.mlir
new file mode 100644
index 0000000000000..159706ccb53fd
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/global_transpose_load.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics -convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx942 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
+
+// CHECK-LABEL: func @global_transpose_load_8xf16
+func.func @global_transpose_load_8xf16(%i : index, %j : index,
+    %src : memref<128x256xf16, #gpu.address_space<global>>) -> vector<8xf16> {
+  // CHECK: rocdl.global.load.tr.b128
+  // CHECK-SAME: -> vector<8xf16>
+  // CHECK-NOT: llvm.bitcast
+  // CHECK-OLD: error: 'amdgpu.global_transpose_load' op global_transpose_load is only supported on gfx1250+
+  %0 = amdgpu.global_transpose_load %src[%i, %j]
+         : memref<128x256xf16, #gpu.address_space<global>> -> vector<8xf16>
+  return %0 : vector<8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @global_transpose_load_8xi8
+func.func @global_transpose_load_8xi8(%i : index, %j : index,
+    %src : memref<128x256xi8, #gpu.address_space<global>>) -> vector<8xi8> {
+  // CHECK: %[[RES:.*]] = rocdl.global.load.tr.b64
+  // CHECK-SAME: -> vector<2xi32>
+  // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
+  // CHECK-OLD: error: 'amdgpu.global_transpose_load' op global_transpose_load is only supported on gfx1250+
+  %0 = amdgpu.global_transpose_load %src[%i, %j]
+         : memref<128x256xi8, #gpu.address_space<global>> -> vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// -----
+
+func.func @global_transpose_load_wrong_addrspace(%i : index, %j : index,
+    %src : memref<128x256xf16, 3>) -> vector<8xf16> {
+  // expected-error at +1 {{'amdgpu.global_transpose_load' op source memory address space must be Global}}
+  %0 = amdgpu.global_transpose_load %src[%i, %j]
+         : memref<128x256xf16, 3> -> vector<8xf16>
+  return %0 : vector<8xf16>
+}
+
+// -----
+
+func.func @global_transpose_load_unsupported_f32(%i : index, %j : index,
+    %src : memref<128x256xf32, #gpu.address_space<global>>) -> vector<8xf32> {
+  // expected-error at +1 {{'amdgpu.global_transpose_load' op unsupported element type size for global transpose load: 32 bits}}
+  %0 = amdgpu.global_transpose_load %src[%i, %j]
+         : memref<128x256xf32, #gpu.address_space<global>> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @global_transpose_load_wrong_num_elements(%i : index, %j : index,
+    %src : memref<128x256xf16, #gpu.address_space<global>>) -> vector<4xf16> {
+  // expected-error at +1 {{'amdgpu.global_transpose_load' op transferring type size mismatch: expected num of elements: 8}}
+  %0 = amdgpu.global_transpose_load %src[%i, %j]
+         : memref<128x256xf16, #gpu.address_space<global>> -> vector<4xf16>
+  return %0 : vector<4xf16>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/195287


More information about the Mlir-commits mailing list