[Mlir-commits] [mlir] dae0ef5 - [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (#133498)
llvmlistbot at llvm.org
Tue Apr 8 06:18:34 PDT 2025
Author: Alan Li
Date: 2025-04-08T09:18:30-04:00
New Revision: dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7
URL: https://github.com/llvm/llvm-project/commit/dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7
DIFF: https://github.com/llvm/llvm-project/commit/dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7.diff
LOG: [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (#133498)
Define a new `amdgpu.gather_to_lds` op, a thin wrapper around the ROCDL
`global_load_lds` intrinsic, along with its lowering logic to
`rocdl.global.load.lds`.
Added:
mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
Modified:
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9cdd961d96ff5..108d7237ff703 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -767,4 +767,40 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def AMDGPU_GatherToLDSOp :
+ AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
+ Arguments<(ins
+ Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
+ Variadic<Index>:$srcIndices,
+ Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+ Variadic<Index>:$dstIndices,
+ TypeAttr:$transferType
+ )>,
+ Results<(outs)> {
+ let summary = "MLIR wrapper for global to LDS load intrinsics";
+ let description = [{
+ The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions.
+
+ Operands:
+ * `$src`: global memory memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$dst`: LDS memory memref to write to.
+ * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+ The elements gathered by the subgroup will be written contiguously, in order of
+ lane ID, starting at `$dst[$dstIndices]`.
+ * `$transferType`: type of the data to be transferred by each thread. This is used to determine
+ the size of the data to be transferred and the number of threads in the subgroup.
+ The transfer type must be a scalar type or a vector type with a single element type.
+
+ The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+ will write to.
+
+ Note: only enabled for gfx942 and later.
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // AMDGPU
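
For reference, a minimal use of the new op (mirroring the f32 test case added below; the memref shapes and indices are illustrative) looks like:

  func.func @gather_example(%global : memref<128x72xf32, 1>) {
    %c0 = arith.constant 0 : index
    %c12 = arith.constant 12 : index
    %c32 = arith.constant 32 : index
    %alloc = memref.alloc() : memref<64x64xf32, 3>
    // Each lane contributes one f32; the subgroup's elements are written
    // contiguously into LDS starting at %alloc[%c32, %c0].
    amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
      : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
    func.return
  }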
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 56d40d6d123bf..5f697bdeef566 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+ GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx942)
+ return op.emitOpError("chipset not supported");
+
+ Location loc = op.getLoc();
+
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+ // TODO: instead of only transferring one element per thread, we could
+ // augment it to transfer multiple elements per thread by issuing multiple
+ // `global_load_lds` instructions.
+ Type transferType = op.getTransferType();
+ size_t loadWidth = [&]() -> size_t {
+ if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
+ return transferVectorType.getNumElements() *
+ (transferVectorType.getElementTypeBitWidth() / 8);
+ } else {
+ return transferType.getIntOrFloatBitWidth() / 8;
+ }
+ }();
+
+ // Currently only 1, 2, and 4 byte loads are supported.
+ if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
+ return op.emitOpError("chipset unsupported element size");
+
+ Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()), rewriter);
+ Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+ (adaptor.getDstIndices()), rewriter);
+
+ rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+ op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+ createI32Constant(rewriter, loc, 0),
+ createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
+ ArrayAttr{});
+
+ return success();
+ }
+};
+
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
- chipset);
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering>(converter, chipset);
}
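
The net effect of the lowering, per the FileCheck expectations in the new test below, is a single `rocdl.global.load.lds` per op, taking the strided element pointers for source and destination plus the byte width of the transfer type (a sketch for the f32 case; type annotations and the trailing offset/aux operands, both constant 0 here, are elided):

  // %global_base / %lds_base come from the converted memref descriptors;
  // the offsets are computed from the op's indices and the memref strides.
  %global_ptr = llvm.getelementptr %global_base[%src_offset] ...
  %lds_ptr = llvm.getelementptr %lds_base[%dst_offset] ...
  %c4 = llvm.mlir.constant(4 : i32) : i32
  rocdl.global.load.lds %global_ptr, %lds_ptr, %c4, ...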
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1e482515a4ee0..7f286f938ee60 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
#include <limits>
#include <optional>
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() {
return success();
}
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return true;
+ if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+ return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+ if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+ if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+ return intMemorySpace.getInt() == 3;
+ if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
- Attribute memorySpace = bufferType.getMemorySpace();
- bool isGlobal = false;
- if (!memorySpace)
- isGlobal = true;
- else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
- isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
- else if (auto gpuMemorySpace =
- llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
- isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
if (!isGlobal)
return op.emitOpError(
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() {
return success();
}
+LogicalResult GatherToLDSOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+ MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+ if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
+ return emitOpError(
+ "destination types must have static shape and contiguous");
+
+ auto elemType = srcType.getElementType();
+ // Check $src and $dst element types are the same.
+ if (elemType != dstType.getElementType())
+ return emitOpError("source and destination element types must match");
+
+ // copy type sizes should be 1, 2, or 4 bytes.
+ auto transferType = getTransferType();
+ size_t transferSize;
+ if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+ transferSize = vectorTransfer.getNumElements() *
+ vectorTransfer.getElementTypeBitWidth();
+ } else {
+ transferSize = transferType.getIntOrFloatBitWidth();
+ }
+ if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+ return emitOpError("Transfering type size must be 8, 16, or 32 bits");
+
+ if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+ return emitOpError("source memory address space must be Global");
+
+ if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
+ return emitOpError("destination memory address space must be Workgroup");
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
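
As an illustration of the new verifier (hypothetical IR, not part of the patch): allocating the destination in global memory instead of LDS trips the address-space check:

  // Invalid: the destination must live in workgroup memory (LDS, address
  // space 3); address space 1 makes the verifier emit
  // "destination memory address space must be Workgroup".
  func.func @bad_gather(%global : memref<128x72xf32, 1>) {
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<64x64xf32, 1>
    amdgpu.gather_to_lds %global[%c0, %c0], %alloc[%c0, %c0]
      : f32, memref<128x72xf32, 1>, memref<64x64xf32, 1>
    func.return
  }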
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..b1c16bd5db079
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]]
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+ func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+ func.return
+}
+
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+ // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+ // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+ %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+ %c0 = arith.constant 0 : index
+ amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+ : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+ func.return
+}