[Mlir-commits] [mlir] [amdgpu][mlir] implement amdgpu.cluster_load_async_to_lds (PR #195410)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 2 22:39:58 PDT 2026
https://github.com/efric updated https://github.com/llvm/llvm-project/pull/195410
From 0545b33b9c6fe66f4bd6147deed3930affd2588b Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 1 May 2026 21:20:54 -0700
Subject: [PATCH 1/2] cluster load to lds
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td | 44 ++++++++-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 84 +++++++++++++++--
mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp | 90 ++++++++++++++----
.../Conversion/AMDGPUToROCDL/gfx1250.mlir | 64 +++++++++++++
mlir/test/Dialect/AMDGPU/invalid.mlir | 93 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/ops.mlir | 36 +++++++
6 files changed, 384 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 4112ea281bb96..e6b851153ce92 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1382,7 +1382,8 @@ def AMDGPU_GlobalLoadAsyncToLDSOp :
AMDGPU wrapper for `global.load.async.to.lds` instructions, which performs
asynchronous load of data from global memory into LDS while bypassing VGPRs.
- * `$src`: global memory memref to read from (global addrspace only, no fat buffer).
+ * `$src`: global memory memref to read from (global addrspace only, no
+ fat buffer).
* `$srcIndices`: indices into `$src` for this thread's global read location.
* `$dst`: LDS memref to write to (workgroup addrspace).
* `$dstIndices`: indices into `$dst` for this thread's LDS write location.
@@ -1415,6 +1416,47 @@ def AMDGPU_GlobalLoadAsyncToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_ClusterLoadAsyncToLDSOp :
+ AMDGPU_Op<"cluster_load_async_to_lds", [AttrSizedOperandSegments]>,
+ Arguments<(ins
+ Arg<AnyMemRef, "global memory to load from", [MemRead]>:$src,
+ Variadic<Index>:$srcIndices,
+ Arg<AnyMemRef, "LDS memory to write to", [MemWrite]>:$dst,
+ Variadic<Index>:$dstIndices,
+ I32:$clusterMask,
+ TypeAttr:$transferType
+ )>,
+ Results<(outs)> {
+ let summary = "MLIR wrapper for async cluster global load to LDS instructions";
+ let description = [{
+ AMDGPU wrapper for `cluster.load.async.to.lds` instructions, which broadcast
+ a global memory load to LDS across a cluster of workgroups while bypassing
+ VGPRs.
+
+ * `$src`: global memory memref to read from (global addrspace only, no fat buffer).
+ * `$srcIndices`: indices into `$src` for this thread's global read location.
+ * `$dst`: LDS memref to write to (workgroup addrspace).
+ * `$dstIndices`: indices into `$dst` for this thread's LDS write location.
+ * `$clusterMask`: i32 mask selecting the cluster participants.
+ * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or 128-bit
+ scalar or vector type.
+
+ Note: only supported on gfx1250 and later.
+
+ Example:
+ ```mlir
+ amdgpu.cluster_load_async_to_lds %src[%i, %j], %dst[%k, %l], %mask
+ : f32, memref<128x64xf32, #gpu.address_space<global>>,
+ memref<64x64xf32, #gpu.address_space<workgroup>>
+ ```
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $clusterMask
+ attr-dict `:` $transferType `,` type($src) `,` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_TransposeLoadOp :
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8464d1e29f0aa..967efd342a4f9 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2240,8 +2240,8 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
Location loc = op.getLoc();
- auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
- auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+ MemRefType srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ MemRefType dstMemRefType = cast<MemRefType>(op.getDst().getType());
// TODO: instead of only transfering one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
@@ -2289,6 +2289,13 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
}
};
+static int64_t getTransferSizeInBits(Type transferType) {
+ if (VectorType transferVectorType = dyn_cast<VectorType>(transferType))
+ return transferVectorType.getNumElements() *
+ transferVectorType.getElementTypeBitWidth();
+ return transferType.getIntOrFloatBitWidth();
+}
+
struct GlobalLoadAsyncToLDSOpLowering
: public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
@@ -2311,11 +2318,7 @@ struct GlobalLoadAsyncToLDSOpLowering
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
Type transferType = op.getTransferType();
- int transferBits =
- isa<VectorType>(transferType)
- ? cast<VectorType>(transferType).getNumElements() *
- cast<VectorType>(transferType).getElementTypeBitWidth()
- : transferType.getIntOrFloatBitWidth();
+ int64_t transferBits = getTransferSizeInBits(transferType);
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
@@ -2366,6 +2369,68 @@ struct GlobalLoadAsyncToLDSOpLowering
}
};
+struct ClusterLoadAsyncToLDSOpLowering
+ : public ConvertOpToLLVMPattern<ClusterLoadAsyncToLDSOp> {
+ ClusterLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<ClusterLoadAsyncToLDSOp>(converter),
+ chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ClusterLoadAsyncToLDSOp op,
+ ClusterLoadAsyncToLDSOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx1250)
+ return op.emitOpError(
+ "cluster_load_async_to_lds is only supported on gfx1250+");
+
+ Location loc = op.getLoc();
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+ Type transferType = op.getTransferType();
+ int64_t transferBits = getTransferSizeInBits(transferType);
+
+ Value srcPtr =
+ getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+ adaptor.getSrcIndices());
+ Value dstPtr =
+ getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+ adaptor.getDstIndices());
+ IntegerAttr offset = rewriter.getI32IntegerAttr(0);
+ IntegerAttr cpol = rewriter.getI32IntegerAttr(0);
+ Value clusterMask = adaptor.getClusterMask();
+
+ switch (transferBits) {
+ case 8:
+ rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB8Op>(
+ op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+ ArrayAttr{}, ArrayAttr{});
+ break;
+ case 32:
+ rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB32Op>(
+ op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+ ArrayAttr{}, ArrayAttr{});
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB64Op>(
+ op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+ ArrayAttr{}, ArrayAttr{});
+ break;
+ case 128:
+ rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB128Op>(
+ op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+ ArrayAttr{}, ArrayAttr{});
+ break;
+ default:
+ return op.emitOpError("unsupported transfer width");
+ }
+ return success();
+ }
+};
+
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -4407,8 +4472,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
- AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+ GlobalLoadAsyncToLDSOpLowering, ClusterLoadAsyncToLDSOpLowering,
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 2f6f59194fba3..2f2fffcbcd871 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -147,6 +147,15 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
return false;
}
+static bool hasExplicitGlobalMemorySpace(Attribute memorySpace) {
+ if (auto intMemorySpace = dyn_cast_if_present<IntegerAttr>(memorySpace))
+ return intMemorySpace.getInt() == 1;
+ if (auto gpuMemorySpace =
+ dyn_cast_if_present<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ return false;
+}
+
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
if (!memorySpace)
return false;
@@ -1016,33 +1025,80 @@ void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
// GlobalLoadAsyncToLDSOp
//===----------------------------------------------------------------------===//
-LogicalResult GlobalLoadAsyncToLDSOp::verify() {
- MemRefType srcType = cast<MemRefType>(getSrc().getType());
- MemRefType dstType = cast<MemRefType>(getDst().getType());
+static std::optional<int64_t> getTransferSizeInBits(Type transferType) {
+ Type elementType = transferType;
+ int64_t numElements = 1;
+ if (auto vectorType = dyn_cast<VectorType>(transferType)) {
+ elementType = vectorType.getElementType();
+ numElements = vectorType.getNumElements();
+ }
+
+ if (!elementType.isIntOrFloat())
+ return std::nullopt;
+
+ return numElements * elementType.getIntOrFloatBitWidth();
+}
+static LogicalResult
+verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
+ MemRefType dstType, OperandRange srcIndices,
+ OperandRange dstIndices, Type transferType,
+ bool requireExplicitGlobal) {
if (srcType.getElementType() != dstType.getElementType())
- return emitOpError("source and destination element types must match");
+ return op->emitOpError("source and destination element types must match");
- Type transferType = getTransferType();
- int transferSize;
- if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
- transferSize = vectorTransfer.getNumElements() *
- vectorTransfer.getElementTypeBitWidth();
- } else {
- transferSize = transferType.getIntOrFloatBitWidth();
- }
- if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
- return emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
+ if (srcType.getRank() != static_cast<int64_t>(srcIndices.size()))
+ return op->emitOpError("source index count must match source memref rank");
+
+ if (dstType.getRank() != static_cast<int64_t>(dstIndices.size()))
+ return op->emitOpError(
+ "destination index count must match destination memref rank");
+
+ std::optional<int64_t> transferSize = getTransferSizeInBits(transferType);
+ if (!transferSize)
+ return op->emitOpError(
+ "transfer type must be an integer, float, or vector of integers or "
+ "floats");
+
+ if (!llvm::is_contained({8, 32, 64, 128}, *transferSize))
+ return op->emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
- if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
- return emitOpError("source memory address space must be global");
+ bool hasValidGlobalMemorySpace =
+ requireExplicitGlobal
+ ? hasExplicitGlobalMemorySpace(srcType.getMemorySpace())
+ : hasGlobalMemorySpace(srcType.getMemorySpace());
+ if (!hasValidGlobalMemorySpace)
+ return op->emitOpError("source memory address space must be global");
if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
- return emitOpError("destination memory address space must be Workgroup");
+ return op->emitOpError(
+ "destination memory address space must be Workgroup");
return success();
}
+LogicalResult GlobalLoadAsyncToLDSOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+ MemRefType dstType = cast<MemRefType>(getDst().getType());
+ return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
+ getSrcIndices(), getDstIndices(),
+ getTransferType(),
+ /*requireExplicitGlobal=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// ClusterLoadAsyncToLDSOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ClusterLoadAsyncToLDSOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+ MemRefType dstType = cast<MemRefType>(getDst().getType());
+ return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
+ getSrcIndices(), getDstIndices(),
+ getTransferType(),
+ /*requireExplicitGlobal=*/true);
+}
+
//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index e43ece8c74fdf..b6200657ef202 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -1031,3 +1031,67 @@ func.func @global_load_async_to_lds_b128_masked(
memref<64x64xf32, #gpu.address_space<workgroup>>
func.return
}
+
+// -----
+// cluster_load_async_to_lds_bN
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b32
+func.func @cluster_load_async_to_lds_b32(
+ %global : memref<128x72xf32, #gpu.address_space<global>>,
+ %mask : i32) {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
+ // CHECK: rocdl.cluster.load.async.to.lds.b32
+ amdgpu.cluster_load_async_to_lds %global[%c12, %c0],
+ %alloc[%c32, %c0], %mask
+ : f32, memref<128x72xf32, #gpu.address_space<global>>,
+ memref<64x64xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b8
+func.func @cluster_load_async_to_lds_b8(
+ %global : memref<128x72xi8, #gpu.address_space<global>>, %mask : i32) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<64x64xi8, #gpu.address_space<workgroup>>
+ // CHECK: rocdl.cluster.load.async.to.lds.b8
+ amdgpu.cluster_load_async_to_lds %global[%c0, %c0], %alloc[%c0, %c0],
+ %mask
+ : i8, memref<128x72xi8, #gpu.address_space<global>>,
+ memref<64x64xi8, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b64
+func.func @cluster_load_async_to_lds_b64(
+ %global : memref<128x72xf32, #gpu.address_space<global>>,
+ %mask : i32) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
+ // CHECK: rocdl.cluster.load.async.to.lds.b64 {{.*}}, {{.*}}, 0, 0, {{.*}}
+ amdgpu.cluster_load_async_to_lds %global[%c0, %c0], %alloc[%c0, %c0],
+ %mask
+ : vector<2xf32>, memref<128x72xf32, #gpu.address_space<global>>,
+ memref<64x64xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b128_dynamic_indices
+func.func @cluster_load_async_to_lds_b128_dynamic_indices(
+ %global : memref<512xi32, #gpu.address_space<global>>,
+ %src_idx : index, %dst_idx : index, %mask : i32) {
+ %alloc = memref.alloc() : memref<256xi32, #gpu.address_space<workgroup>>
+ // CHECK: rocdl.cluster.load.async.to.lds.b128
+ amdgpu.cluster_load_async_to_lds %global[%src_idx], %alloc[%dst_idx], %mask
+ : vector<4xi32>, memref<512xi32, #gpu.address_space<global>>,
+ memref<256xi32, #gpu.address_space<workgroup>>
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 2958b0fe2bc51..cd21be8b8b5e9 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -368,6 +368,99 @@ func.func @global_load_async_to_lds_src_not_global(%idx1 : index,
// -----
+func.func @global_load_async_to_lds_src_index_count(%idx1 : index,
+ %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
+ %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.global_load_async_to_lds' op source index count must match source memref rank}}
+ amdgpu.global_load_async_to_lds %mem1[%idx1], %mem2[%idx1]
+ : f32, memref<32x32xf32, #gpu.address_space<global>>,
+ memref<32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_non_lds(%idx1 : index, %mask : i32,
+ %mem1 : memref<32xf32, #gpu.address_space<global>>,
+ %mem2 : memref<32xf32>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op destination memory address space must be Workgroup}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32xf32, #gpu.address_space<global>>, memref<32xf32>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_bad_size_16bit(%idx1 : index, %mask : i32,
+ %mem1 : memref<32xf16, #gpu.address_space<global>>,
+ %mem2 : memref<32xf16, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op transfer type size must be 8, 32, 64, or 128 bits}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f16, memref<32xf16, #gpu.address_space<global>>,
+ memref<32xf16, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_not_global(%idx1 : index, %mask : i32,
+ %mem1 : memref<32xf32, #gpu.address_space<workgroup>>,
+ %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32xf32, #gpu.address_space<workgroup>>,
+ memref<32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_default_memory_space(
+ %idx1 : index, %mask : i32, %mem1 : memref<32xf32>,
+ %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32xf32>, memref<32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_memory_space_zero(
+ %idx1 : index, %mask : i32, %mem1 : memref<32xf32, 0>,
+ %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32xf32, 0>, memref<32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_index_count(%idx1 : index,
+ %mask : i32, %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
+ %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source index count must match source memref rank}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32x32xf32, #gpu.address_space<global>>,
+ memref<32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_dst_index_count(%idx1 : index,
+ %mask : i32, %mem1 : memref<32xf32, #gpu.address_space<global>>,
+ %mem2 : memref<32x32xf32, #gpu.address_space<workgroup>>) {
+ // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op destination index count must match destination memref rank}}
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+ : f32, memref<32xf32, #gpu.address_space<global>>,
+ memref<32x32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// -----
+
func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// expected-error @+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
%0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 606a7768974bf..918587214411f 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -730,6 +730,42 @@ func.func @global_load_async_to_lds_0d(%mem1 : memref<f32, #gpu.address_space<gl
func.return
}
+// CHECK-LABEL: func @cluster_load_async_to_lds
+func.func @cluster_load_async_to_lds(
+ %idx1 : index, %idx2 : index,
+ %mem1 : memref<32xf32, #gpu.address_space<global>>,
+ %mem2 : memref<32x32xf32, #gpu.address_space<global>>,
+ %smem2 : memref<32x32xf32, #gpu.address_space<workgroup>>,
+ %mask : i32) {
+ // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+ // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+ // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+ amdgpu.cluster_load_async_to_lds %mem2[%idx1, %idx2],
+ %smem2[%idx1, %idx2], %mask
+ : f32, memref<32x32xf32, #gpu.address_space<global>>,
+ memref<32x32xf32, #gpu.address_space<workgroup>>
+ amdgpu.cluster_load_async_to_lds %mem1[%idx1],
+ %smem2[%idx1, %idx2], %mask
+ : f32, memref<32xf32, #gpu.address_space<global>>,
+ memref<32x32xf32, #gpu.address_space<workgroup>>
+ amdgpu.cluster_load_async_to_lds %mem2[%idx1, %idx2],
+ %smem2[%idx1, %idx2], %mask
+ : vector<2xf32>, memref<32x32xf32, #gpu.address_space<global>>,
+ memref<32x32xf32, #gpu.address_space<workgroup>>
+ func.return
+}
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_0d
+func.func @cluster_load_async_to_lds_0d(
+ %mem1 : memref<f32, #gpu.address_space<global>>,
+ %smem1 : memref<f32, #gpu.address_space<workgroup>>, %mask : i32) {
+ // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[], %{{.*}}[], %{{.*}}
+ amdgpu.cluster_load_async_to_lds %mem1[], %smem1[], %mask
+ : f32, memref<f32, #gpu.address_space<global>>,
+ memref<f32, #gpu.address_space<workgroup>>
+ func.return
+}
+
// CHECK-LABEL: func @memory_counter_wait
func.func @memory_counter_wait() {
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
From 9cf37dc28e63fa61cb1bcb80361986dd5f8550b4 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Sat, 2 May 2026 22:36:43 -0700
Subject: [PATCH 2/2] polish
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td | 8 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 ++---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp | 40 ++++++-------------
mlir/test/Dialect/AMDGPU/invalid.mlir | 22 ----------
4 files changed, 21 insertions(+), 59 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index e6b851153ce92..8843e445fb3fa 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1430,15 +1430,15 @@ def AMDGPU_ClusterLoadAsyncToLDSOp :
let summary = "MLIR wrapper for async cluster global load to LDS instructions";
let description = [{
AMDGPU wrapper for `cluster.load.async.to.lds` instructions, which broadcast
- a global memory load to LDS across a cluster of workgroups while bypassing
- VGPRs.
+ a global memory load to LDS from a single wave (subgroup) across a cluster of
+ workgroups while bypassing VGPRs.
* `$src`: global memory memref to read from (global addrspace only, no fat buffer).
* `$srcIndices`: indices into `$src` for this thread's global read location.
* `$dst`: LDS memref to write to (workgroup addrspace).
* `$dstIndices`: indices into `$dst` for this thread's LDS write location.
- * `$clusterMask`: i32 mask selecting the cluster participants.
- * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or 128-bit
+ * `$clusterMask`: i32 mask selecting the workgroups which are part of the cluster.
+ * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or 128 bit
scalar or vector type.
Note: only supported on gfx1250 and later.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 967efd342a4f9..470ef56f377d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2240,8 +2240,8 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
Location loc = op.getLoc();
- MemRefType srcMemRefType = cast<MemRefType>(op.getSrc().getType());
- MemRefType dstMemRefType = cast<MemRefType>(op.getDst().getType());
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// TODO: instead of only transfering one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
@@ -2289,7 +2289,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
}
};
-static int64_t getTransferSizeInBits(Type transferType) {
+static unsigned getTransferSizeInBits(Type transferType) {
if (VectorType transferVectorType = dyn_cast<VectorType>(transferType))
return transferVectorType.getNumElements() *
transferVectorType.getElementTypeBitWidth();
@@ -2318,7 +2318,7 @@ struct GlobalLoadAsyncToLDSOpLowering
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
Type transferType = op.getTransferType();
- int64_t transferBits = getTransferSizeInBits(transferType);
+ unsigned transferBits = getTransferSizeInBits(transferType);
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
@@ -2391,7 +2391,7 @@ struct ClusterLoadAsyncToLDSOpLowering
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
Type transferType = op.getTransferType();
- int64_t transferBits = getTransferSizeInBits(transferType);
+ unsigned transferBits = getTransferSizeInBits(transferType);
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 2f2fffcbcd871..ff810f6ad4093 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -147,15 +147,6 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
return false;
}
-static bool hasExplicitGlobalMemorySpace(Attribute memorySpace) {
- if (auto intMemorySpace = dyn_cast_if_present<IntegerAttr>(memorySpace))
- return intMemorySpace.getInt() == 1;
- if (auto gpuMemorySpace =
- dyn_cast_if_present<gpu::AddressSpaceAttr>(memorySpace))
- return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
- return false;
-}
-
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
if (!memorySpace)
return false;
@@ -1025,9 +1016,9 @@ void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
// GlobalLoadAsyncToLDSOp
//===----------------------------------------------------------------------===//
-static std::optional<int64_t> getTransferSizeInBits(Type transferType) {
+static std::optional<unsigned> getTransferSizeInBits(Type transferType) {
Type elementType = transferType;
- int64_t numElements = 1;
+ unsigned numElements = 1;
if (auto vectorType = dyn_cast<VectorType>(transferType)) {
elementType = vectorType.getElementType();
numElements = vectorType.getNumElements();
@@ -1042,8 +1033,7 @@ static std::optional<int64_t> getTransferSizeInBits(Type transferType) {
static LogicalResult
verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
MemRefType dstType, OperandRange srcIndices,
- OperandRange dstIndices, Type transferType,
- bool requireExplicitGlobal) {
+ OperandRange dstIndices, Type transferType) {
if (srcType.getElementType() != dstType.getElementType())
return op->emitOpError("source and destination element types must match");
@@ -1054,20 +1044,16 @@ verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
return op->emitOpError(
"destination index count must match destination memref rank");
- std::optional<int64_t> transferSize = getTransferSizeInBits(transferType);
+ std::optional<unsigned> transferSize = getTransferSizeInBits(transferType);
if (!transferSize)
return op->emitOpError(
"transfer type must be an integer, float, or vector of integers or "
"floats");
- if (!llvm::is_contained({8, 32, 64, 128}, *transferSize))
+ if (!llvm::is_contained({8u, 32u, 64u, 128u}, transferSize.value()))
return op->emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
- bool hasValidGlobalMemorySpace =
- requireExplicitGlobal
- ? hasExplicitGlobalMemorySpace(srcType.getMemorySpace())
- : hasGlobalMemorySpace(srcType.getMemorySpace());
- if (!hasValidGlobalMemorySpace)
+ if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
return op->emitOpError("source memory address space must be global");
if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
@@ -1078,12 +1064,11 @@ verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
}
LogicalResult GlobalLoadAsyncToLDSOp::verify() {
- MemRefType srcType = cast<MemRefType>(getSrc().getType());
- MemRefType dstType = cast<MemRefType>(getDst().getType());
+ auto srcType = cast<MemRefType>(getSrc().getType());
+ auto dstType = cast<MemRefType>(getDst().getType());
return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
getSrcIndices(), getDstIndices(),
- getTransferType(),
- /*requireExplicitGlobal=*/false);
+ getTransferType());
}
//===----------------------------------------------------------------------===//
@@ -1091,12 +1076,11 @@ LogicalResult GlobalLoadAsyncToLDSOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult ClusterLoadAsyncToLDSOp::verify() {
- MemRefType srcType = cast<MemRefType>(getSrc().getType());
- MemRefType dstType = cast<MemRefType>(getDst().getType());
+ auto srcType = cast<MemRefType>(getSrc().getType());
+ auto dstType = cast<MemRefType>(getDst().getType());
return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
getSrcIndices(), getDstIndices(),
- getTransferType(),
- /*requireExplicitGlobal=*/true);
+ getTransferType());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index cd21be8b8b5e9..1d00edfaedac0 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -415,28 +415,6 @@ func.func @cluster_load_async_to_lds_src_not_global(%idx1 : index, %mask : i32,
// -----
-func.func @cluster_load_async_to_lds_src_default_memory_space(
- %idx1 : index, %mask : i32, %mem1 : memref<32xf32>,
- %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
- // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
- amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
- : f32, memref<32xf32>, memref<32xf32, #gpu.address_space<workgroup>>
- func.return
-}
-
-// -----
-
-func.func @cluster_load_async_to_lds_src_memory_space_zero(
- %idx1 : index, %mask : i32, %mem1 : memref<32xf32, 0>,
- %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
- // expected-error @+1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
- amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
- : f32, memref<32xf32, 0>, memref<32xf32, #gpu.address_space<workgroup>>
- func.return
-}
-
-// -----
-
func.func @cluster_load_async_to_lds_src_index_count(%idx1 : index,
%mask : i32, %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
%mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
More information about the Mlir-commits
mailing list