[Mlir-commits] [mlir] 3f3282c - [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 25 19:58:17 PDT 2025
Author: Alan Li
Date: 2025-06-25T22:58:14-04:00
New Revision: 3f3282cee87f307afe58c899f03df3a882846290
URL: https://github.com/llvm/llvm-project/commit/3f3282cee87f307afe58c899f03df3a882846290
DIFF: https://github.com/llvm/llvm-project/commit/3f3282cee87f307afe58c899f03df3a882846290.diff
LOG: [AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)
* 1-to-1 mapping wrapper op.
* Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
Added:
mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/test/Dialect/AMDGPU/invalid.mlir
mlir/test/Dialect/AMDGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..eadb5d9326798 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_TransposeLoadOp :
+ AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+ Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+ Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
+ let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+ let description = [{
+ The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+ The transpose load op represents a subgroup load from LDS memory,
+ where the subgroup of threads collectively reads a matrix from the source
+ memref, with each thread reading a vector of the matrix, and getting a transposed
+ matrix as the result. That is, each thread reads a vector of the col-major matrix at
+ different indices, and the thread's read result is a vector of the corresponding row
+ of the transposed matrix.
+
+ This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer
+ to the CDNA4 ISA documentation for more details about its exact semantics.
+
+ Format example:
+ ```
+ %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
+ ```
+ Operands:
+ * `$src`: LDS memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$result`: target register this transpose load instruction will write to.
+
+ Note: Lowering is only supported on gfx950 and up.
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..910fe1b1d93c1 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct TransposeLoadOpLowering
+ : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+ TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset != kGfx950)
+ return op.emitOpError("Non-gfx950 chipset not supported");
+
+ Location loc = op.getLoc();
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+
+ // Elements in subbyte memrefs are stored non-contiguously,
+ // reject if source is sub-byte memref. Use emulated memrefs instead.
+ size_t srcElementSize =
+ srcMemRefType.getElementType().getIntOrFloatBitWidth();
+ if (srcElementSize < 8)
+ return op.emitOpError("Expect source memref to have at least 8 bits "
+ "element size, got ")
+ << srcElementSize;
+
+ auto resultType = cast<VectorType>(op.getResult().getType());
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()));
+
+ size_t numElements = resultType.getNumElements();
+ size_t elementTypeSize =
+ resultType.getElementType().getIntOrFloatBitWidth();
+
+ // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
+ // the element size is smaller than 16 bits.
+ Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
+ rewriter.getIntegerType(32));
+ Type llvmResultType = typeConverter->convertType(resultType);
+
+ switch (elementTypeSize) {
+ case 4: {
+ assert(numElements == 16);
+ auto rocdlOp =
+ rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+ break;
+ }
+ case 6: {
+ assert(numElements == 16);
+ auto rocdlOp =
+ rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+ break;
+ }
+ case 8: {
+ assert(numElements == 8);
+ auto rocdlOp =
+ rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
+ break;
+ }
+ case 16: {
+ assert(numElements == 4);
+ rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
+ srcPtr);
+ break;
+ }
+ default:
+ return op.emitOpError("Unsupported element size for transpose load");
+ }
+ return success();
+ }
+};
+
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
- chipset);
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+ TransposeLoadOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..4613d14461969 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
@@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+LogicalResult TransposeLoadOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+ if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+ return emitOpError("source memory address space must be Workgroup");
+
+ auto transferType = cast<VectorType>(getType());
+ size_t numElements = transferType.getNumElements();
+ size_t elementTypeSize =
+ transferType.getElementType().getIntOrFloatBitWidth();
+
+ // ElementSize -> NumElements
+ const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
+ {4, 16},
+ {6, 16},
+ {8, 8},
+ {16, 4},
+ };
+
+ auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
+ if (validNumElems == KValidLoadSizeMap.end()) {
+ return emitOpError("Unsupported element type size for transpose load: ")
+ << elementTypeSize << " bits";
+ }
+ if (numElements != validNumElems->second) {
+ return emitOpError(
+ "Transferring type size mismatch: expected num of elements: ")
+ << validNumElems->second;
+ }
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
new file mode 100644
index 0000000000000..68799098f1d36
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
+func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
+ // CHECK: rocdl.ds.read.tr16.b64
+ // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
+func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
+ // CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
+ // CHECK-SAME: -> vector<2xi32>
+ // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
+ // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
+func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
+ // CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
+ // CHECK-SAME: -> vector<2xi32>
+ // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
+ // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
+ return %0 : vector<16xi4>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
+func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
+ // CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
+ // CHECK-SAME: -> vector<3xi32>
+ // CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
+ // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
+ return %0 : vector<16xi6>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
+func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
+ // CHECK: rocdl.ds.read.tr16.b64
+ // CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
+ return %0 : vector<4xi16>
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
new file mode 100644
index 0000000000000..a41051c904ed8
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir
@@ -0,0 +1,17 @@
+// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
+ // CHECK: memref to have at least 8 bits element size, got 4
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
+ return %0 : vector<16xi4>
+}
+
+// -----
+
+func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
+ // CHECK: memref to have at least 8 bits element size, got 6
+ %0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
+ return %0 : vector<16xi6>
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 73306ba6b3f93..6d55583f8bc7c 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
func.return %0 : vector<[4]xf32>
}
+
+// -----
+
+func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
+ func.return %0 : vector<4xf16>
+}
+
+// -----
+
+func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
+ func.return %0 : vector<4xf16>
+}
+
+// -----
+
+func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32>
+ func.return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16>
+ func.return %0 : vector<2xf16>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4>
+ func.return %0 : vector<20xi4>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8>
+ func.return %0 : vector<20xi8>
+}
+
+// -----
+
+func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
+ // expected-error at +1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
+ func.return %0 : vector<8xi6>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 6c3ffb575f7c2..51f3bbd9ae45c 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}
+
+// CHECK-LABEL: func @transpose_load
+func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> {
+ // CHECK: amdgpu.transpose_load
+ %0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
+ func.return %0 : vector<4xf16>
+}
More information about the Mlir-commits
mailing list