[Mlir-commits] [mlir] [amdgpu][mlir] implement amdgpu.cluster_load_async_to_lds (PR #195410)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 2 22:39:58 PDT 2026


https://github.com/efric updated https://github.com/llvm/llvm-project/pull/195410

>From 0545b33b9c6fe66f4bd6147deed3930affd2588b Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 1 May 2026 21:20:54 -0700
Subject: [PATCH 1/2] cluster load to lds

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td       | 44 ++++++++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 84 +++++++++++++++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp      | 90 ++++++++++++++----
 .../Conversion/AMDGPUToROCDL/gfx1250.mlir     | 64 +++++++++++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 93 +++++++++++++++++++
 mlir/test/Dialect/AMDGPU/ops.mlir             | 36 +++++++
 6 files changed, 384 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 4112ea281bb96..e6b851153ce92 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1382,7 +1382,8 @@ def AMDGPU_GlobalLoadAsyncToLDSOp :
     AMDGPU wrapper for `global.load.async.to.lds` instructions, which performs
     asynchronous load of data from global memory into LDS while bypassing VGPRs.
 
-    * `$src`: global memory memref to read from (global addrspace only, no fat buffer).
+    * `$src`: global memory memref to read from (global addrspace only, no
+      fat buffer).
     * `$srcIndices`: indices into `$src` for this thread's global read location.
     * `$dst`: LDS memref to write to (workgroup addrspace).
     * `$dstIndices`: indices into `$dst` for this thread's LDS write location.
@@ -1415,6 +1416,47 @@ def AMDGPU_GlobalLoadAsyncToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_ClusterLoadAsyncToLDSOp :
+    AMDGPU_Op<"cluster_load_async_to_lds", [AttrSizedOperandSegments]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "global memory to load from", [MemRead]>:$src,
+                   Variadic<Index>:$srcIndices,
+                   Arg<AnyMemRef, "LDS memory to write to", [MemWrite]>:$dst,
+                   Variadic<Index>:$dstIndices,
+                   I32:$clusterMask,
+                   TypeAttr:$transferType
+                   )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for async cluster global load to LDS instructions";
+  let description = [{
+    AMDGPU wrapper for `cluster.load.async.to.lds` instructions, which broadcast
+    a global memory load to LDS across a cluster of workgroups while bypassing
+    VGPRs.
+
+    * `$src`: global memory memref to read from (global addrspace only, no fat buffer).
+    * `$srcIndices`: indices into `$src` for this thread's global read location.
+    * `$dst`: LDS memref to write to (workgroup addrspace).
+    * `$dstIndices`: indices into `$dst` for this thread's LDS write location.
+    * `$clusterMask`: i32 mask selecting the cluster participants.
+    * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or 128-bit
+      scalar or vector type.
+
+    Note: only supported on gfx1250 and later.
+
+    Example:
+    ```mlir
+      amdgpu.cluster_load_async_to_lds %src[%i, %j], %dst[%k, %l], %mask
+        : f32, memref<128x64xf32, #gpu.address_space<global>>,
+          memref<64x64xf32, #gpu.address_space<workgroup>>
+    ```
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $clusterMask
+    attr-dict `:` $transferType `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_TransposeLoadOp :
     AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
     Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8464d1e29f0aa..967efd342a4f9 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2240,8 +2240,8 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
 
     Location loc = op.getLoc();
 
-    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
-    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+    MemRefType srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    MemRefType dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
@@ -2289,6 +2289,13 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   }
 };
 
+static int64_t getTransferSizeInBits(Type transferType) {
+  if (VectorType transferVectorType = dyn_cast<VectorType>(transferType))
+    return transferVectorType.getNumElements() *
+           transferVectorType.getElementTypeBitWidth();
+  return transferType.getIntOrFloatBitWidth();
+}
+
 struct GlobalLoadAsyncToLDSOpLowering
     : public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
   GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
@@ -2311,11 +2318,7 @@ struct GlobalLoadAsyncToLDSOpLowering
     auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     Type transferType = op.getTransferType();
-    int transferBits =
-        isa<VectorType>(transferType)
-            ? cast<VectorType>(transferType).getNumElements() *
-                  cast<VectorType>(transferType).getElementTypeBitWidth()
-            : transferType.getIntOrFloatBitWidth();
+    int64_t transferBits = getTransferSizeInBits(transferType);
 
     Value srcPtr =
         getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
@@ -2366,6 +2369,68 @@ struct GlobalLoadAsyncToLDSOpLowering
   }
 };
 
+struct ClusterLoadAsyncToLDSOpLowering
+    : public ConvertOpToLLVMPattern<ClusterLoadAsyncToLDSOp> {
+  ClusterLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
+                                  Chipset chipset)
+      : ConvertOpToLLVMPattern<ClusterLoadAsyncToLDSOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ClusterLoadAsyncToLDSOp op,
+                  ClusterLoadAsyncToLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op.emitOpError(
+          "cluster_load_async_to_lds is only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    Type transferType = op.getTransferType();
+    int64_t transferBits = getTransferSizeInBits(transferType);
+
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             adaptor.getSrcIndices());
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+                             adaptor.getDstIndices());
+    IntegerAttr offset = rewriter.getI32IntegerAttr(0);
+    IntegerAttr cpol = rewriter.getI32IntegerAttr(0);
+    Value clusterMask = adaptor.getClusterMask();
+
+    switch (transferBits) {
+    case 8:
+      rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB8Op>(
+          op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+          ArrayAttr{}, ArrayAttr{});
+      break;
+    case 32:
+      rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB32Op>(
+          op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+          ArrayAttr{}, ArrayAttr{});
+      break;
+    case 64:
+      rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB64Op>(
+          op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+          ArrayAttr{}, ArrayAttr{});
+      break;
+    case 128:
+      rewriter.replaceOpWithNewOp<ROCDL::ClusterLoadAsyncToLDSB128Op>(
+          op, srcPtr, dstPtr, offset, cpol, clusterMask, ArrayAttr{},
+          ArrayAttr{}, ArrayAttr{});
+      break;
+    default:
+      return op.emitOpError("unsupported transfer width");
+    }
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -4407,8 +4472,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
            PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-           GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+           GlobalLoadAsyncToLDSOpLowering, ClusterLoadAsyncToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+           AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
            AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
            AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
            AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 2f6f59194fba3..2f2fffcbcd871 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -147,6 +147,15 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
   return false;
 }
 
+static bool hasExplicitGlobalMemorySpace(Attribute memorySpace) {
+  if (auto intMemorySpace = dyn_cast_if_present<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 1;
+  if (auto gpuMemorySpace =
+          dyn_cast_if_present<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
   if (!memorySpace)
     return false;
@@ -1016,33 +1025,80 @@ void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GlobalLoadAsyncToLDSOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult GlobalLoadAsyncToLDSOp::verify() {
-  MemRefType srcType = cast<MemRefType>(getSrc().getType());
-  MemRefType dstType = cast<MemRefType>(getDst().getType());
+static std::optional<int64_t> getTransferSizeInBits(Type transferType) {
+  Type elementType = transferType;
+  int64_t numElements = 1;
+  if (auto vectorType = dyn_cast<VectorType>(transferType)) {
+    elementType = vectorType.getElementType();
+    numElements = vectorType.getNumElements();
+  }
+
+  if (!elementType.isIntOrFloat())
+    return std::nullopt;
+
+  return numElements * elementType.getIntOrFloatBitWidth();
+}
 
+static LogicalResult
+verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
+                               MemRefType dstType, OperandRange srcIndices,
+                               OperandRange dstIndices, Type transferType,
+                               bool requireExplicitGlobal) {
   if (srcType.getElementType() != dstType.getElementType())
-    return emitOpError("source and destination element types must match");
+    return op->emitOpError("source and destination element types must match");
 
-  Type transferType = getTransferType();
-  int transferSize;
-  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
-    transferSize = vectorTransfer.getNumElements() *
-                   vectorTransfer.getElementTypeBitWidth();
-  } else {
-    transferSize = transferType.getIntOrFloatBitWidth();
-  }
-  if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
-    return emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
+  if (srcType.getRank() != static_cast<int64_t>(srcIndices.size()))
+    return op->emitOpError("source index count must match source memref rank");
+
+  if (dstType.getRank() != static_cast<int64_t>(dstIndices.size()))
+    return op->emitOpError(
+        "destination index count must match destination memref rank");
+
+  std::optional<int64_t> transferSize = getTransferSizeInBits(transferType);
+  if (!transferSize)
+    return op->emitOpError(
+        "transfer type must be an integer, float, or vector of integers or "
+        "floats");
+
+  if (!llvm::is_contained({8, 32, 64, 128}, *transferSize))
+    return op->emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
 
-  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
-    return emitOpError("source memory address space must be global");
+  bool hasValidGlobalMemorySpace =
+      requireExplicitGlobal
+          ? hasExplicitGlobalMemorySpace(srcType.getMemorySpace())
+          : hasGlobalMemorySpace(srcType.getMemorySpace());
+  if (!hasValidGlobalMemorySpace)
+    return op->emitOpError("source memory address space must be global");
 
   if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
-    return emitOpError("destination memory address space must be Workgroup");
+    return op->emitOpError(
+        "destination memory address space must be Workgroup");
 
   return success();
 }
 
+LogicalResult GlobalLoadAsyncToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+  return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
+                                        getSrcIndices(), getDstIndices(),
+                                        getTransferType(),
+                                        /*requireExplicitGlobal=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// ClusterLoadAsyncToLDSOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ClusterLoadAsyncToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+  return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
+                                        getSrcIndices(), getDstIndices(),
+                                        getTransferType(),
+                                        /*requireExplicitGlobal=*/true);
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index e43ece8c74fdf..b6200657ef202 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -1031,3 +1031,67 @@ func.func @global_load_async_to_lds_b128_masked(
       memref<64x64xf32, #gpu.address_space<workgroup>>
   func.return
 }
+
+// -----
+// cluster_load_async_to_lds_bN
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b32
+func.func @cluster_load_async_to_lds_b32(
+    %global : memref<128x72xf32, #gpu.address_space<global>>,
+    %mask : i32) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.cluster.load.async.to.lds.b32
+  amdgpu.cluster_load_async_to_lds %global[%c12, %c0],
+    %alloc[%c32, %c0], %mask
+    : f32, memref<128x72xf32, #gpu.address_space<global>>,
+      memref<64x64xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b8
+func.func @cluster_load_async_to_lds_b8(
+    %global : memref<128x72xi8, #gpu.address_space<global>>, %mask : i32) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.cluster.load.async.to.lds.b8
+  amdgpu.cluster_load_async_to_lds %global[%c0, %c0], %alloc[%c0, %c0],
+    %mask
+    : i8, memref<128x72xi8, #gpu.address_space<global>>,
+      memref<64x64xi8, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b64
+func.func @cluster_load_async_to_lds_b64(
+    %global : memref<128x72xf32, #gpu.address_space<global>>,
+    %mask : i32) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.cluster.load.async.to.lds.b64 {{.*}}, {{.*}}, 0, 0, {{.*}}
+  amdgpu.cluster_load_async_to_lds %global[%c0, %c0], %alloc[%c0, %c0],
+    %mask
+    : vector<2xf32>, memref<128x72xf32, #gpu.address_space<global>>,
+      memref<64x64xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_b128_dynamic_indices
+func.func @cluster_load_async_to_lds_b128_dynamic_indices(
+    %global : memref<512xi32, #gpu.address_space<global>>,
+    %src_idx : index, %dst_idx : index, %mask : i32) {
+  %alloc = memref.alloc() : memref<256xi32, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.cluster.load.async.to.lds.b128
+  amdgpu.cluster_load_async_to_lds %global[%src_idx], %alloc[%dst_idx], %mask
+    : vector<4xi32>, memref<512xi32, #gpu.address_space<global>>,
+      memref<256xi32, #gpu.address_space<workgroup>>
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 2958b0fe2bc51..cd21be8b8b5e9 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -368,6 +368,99 @@ func.func @global_load_async_to_lds_src_not_global(%idx1 : index,
 
 // -----
 
+func.func @global_load_async_to_lds_src_index_count(%idx1 : index,
+    %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
+    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.global_load_async_to_lds' op source index count must match source memref rank}}
+  amdgpu.global_load_async_to_lds %mem1[%idx1], %mem2[%idx1]
+    : f32, memref<32x32xf32, #gpu.address_space<global>>,
+      memref<32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_non_lds(%idx1 : index, %mask : i32,
+    %mem1 : memref<32xf32, #gpu.address_space<global>>,
+    %mem2 : memref<32xf32>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op destination memory address space must be Workgroup}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32xf32, #gpu.address_space<global>>, memref<32xf32>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_bad_size_16bit(%idx1 : index, %mask : i32,
+    %mem1 : memref<32xf16, #gpu.address_space<global>>,
+    %mem2 : memref<32xf16, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op transfer type size must be 8, 32, 64, or 128 bits}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f16, memref<32xf16, #gpu.address_space<global>>,
+      memref<32xf16, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_not_global(%idx1 : index, %mask : i32,
+    %mem1 : memref<32xf32, #gpu.address_space<workgroup>>,
+    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32xf32, #gpu.address_space<workgroup>>,
+      memref<32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_default_memory_space(
+    %idx1 : index, %mask : i32, %mem1 : memref<32xf32>,
+    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32xf32>, memref<32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_memory_space_zero(
+    %idx1 : index, %mask : i32, %mem1 : memref<32xf32, 0>,
+    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32xf32, 0>, memref<32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_src_index_count(%idx1 : index,
+    %mask : i32, %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
+    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source index count must match source memref rank}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32x32xf32, #gpu.address_space<global>>,
+      memref<32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
+func.func @cluster_load_async_to_lds_dst_index_count(%idx1 : index,
+    %mask : i32, %mem1 : memref<32xf32, #gpu.address_space<global>>,
+    %mem2 : memref<32x32xf32, #gpu.address_space<workgroup>>) {
+  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op destination index count must match destination memref rank}}
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
+    : f32, memref<32xf32, #gpu.address_space<global>>,
+      memref<32x32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// -----
+
 func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
   // expected-error at +1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
   %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 606a7768974bf..918587214411f 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -730,6 +730,42 @@ func.func @global_load_async_to_lds_0d(%mem1 : memref<f32, #gpu.address_space<gl
   func.return
 }
 
+// CHECK-LABEL: func @cluster_load_async_to_lds
+func.func @cluster_load_async_to_lds(
+    %idx1 : index, %idx2 : index,
+    %mem1 : memref<32xf32, #gpu.address_space<global>>,
+    %mem2 : memref<32x32xf32, #gpu.address_space<global>>,
+    %smem2 : memref<32x32xf32, #gpu.address_space<workgroup>>,
+    %mask : i32) {
+  // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+  // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+  // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}
+  amdgpu.cluster_load_async_to_lds %mem2[%idx1, %idx2],
+    %smem2[%idx1, %idx2], %mask
+    : f32, memref<32x32xf32, #gpu.address_space<global>>,
+      memref<32x32xf32, #gpu.address_space<workgroup>>
+  amdgpu.cluster_load_async_to_lds %mem1[%idx1],
+    %smem2[%idx1, %idx2], %mask
+    : f32, memref<32xf32, #gpu.address_space<global>>,
+      memref<32x32xf32, #gpu.address_space<workgroup>>
+  amdgpu.cluster_load_async_to_lds %mem2[%idx1, %idx2],
+    %smem2[%idx1, %idx2], %mask
+    : vector<2xf32>, memref<32x32xf32, #gpu.address_space<global>>,
+      memref<32x32xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// CHECK-LABEL: func @cluster_load_async_to_lds_0d
+func.func @cluster_load_async_to_lds_0d(
+    %mem1 : memref<f32, #gpu.address_space<global>>,
+    %smem1 : memref<f32, #gpu.address_space<workgroup>>, %mask : i32) {
+  // CHECK: amdgpu.cluster_load_async_to_lds %{{.*}}[], %{{.*}}[], %{{.*}}
+  amdgpu.cluster_load_async_to_lds %mem1[], %smem1[], %mask
+    : f32, memref<f32, #gpu.address_space<global>>,
+      memref<f32, #gpu.address_space<workgroup>>
+  func.return
+}
+
 // CHECK-LABEL: func @memory_counter_wait
 func.func @memory_counter_wait() {
   // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)

>From 9cf37dc28e63fa61cb1bcb80361986dd5f8550b4 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Sat, 2 May 2026 22:36:43 -0700
Subject: [PATCH 2/2] polish

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td       |  8 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 10 ++---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp      | 40 ++++++-------------
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 22 ----------
 4 files changed, 21 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index e6b851153ce92..8843e445fb3fa 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1430,15 +1430,17 @@ def AMDGPU_ClusterLoadAsyncToLDSOp :
   let summary = "MLIR wrapper for async cluster global load to LDS instructions";
   let description = [{
     AMDGPU wrapper for `cluster.load.async.to.lds` instructions, which broadcast
-    a global memory load to LDS across a cluster of workgroups while bypassing
-    VGPRs.
+    a global memory load to LDS from a single wave (subgroup) across a cluster
+    of workgroups while bypassing VGPRs.
 
-    * `$src`: global memory memref to read from (global addrspace only, no fat buffer).
+    * `$src`: global memory memref to read from (global addrspace only, no
+      fat buffer).
     * `$srcIndices`: indices into `$src` for this thread's global read location.
     * `$dst`: LDS memref to write to (workgroup addrspace).
     * `$dstIndices`: indices into `$dst` for this thread's LDS write location.
-    * `$clusterMask`: i32 mask selecting the cluster participants.
-    * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or 128-bit
-      scalar or vector type.
+    * `$clusterMask`: i32 mask selecting the workgroups that are part of the
+      cluster.
+    * `$transferType`: type of data to be transferred. Must be an 8, 32, 64, or
+      128-bit scalar or vector type.
 
     Note: only supported on gfx1250 and later.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 967efd342a4f9..470ef56f377d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2240,8 +2240,8 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
 
     Location loc = op.getLoc();
 
-    MemRefType srcMemRefType = cast<MemRefType>(op.getSrc().getType());
-    MemRefType dstMemRefType = cast<MemRefType>(op.getDst().getType());
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
@@ -2289,7 +2289,7 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   }
 };
 
-static int64_t getTransferSizeInBits(Type transferType) {
+static unsigned getTransferSizeInBits(Type transferType) {
   if (VectorType transferVectorType = dyn_cast<VectorType>(transferType))
     return transferVectorType.getNumElements() *
            transferVectorType.getElementTypeBitWidth();
@@ -2318,7 +2318,7 @@ struct GlobalLoadAsyncToLDSOpLowering
     auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     Type transferType = op.getTransferType();
-    int64_t transferBits = getTransferSizeInBits(transferType);
+    unsigned transferBits = getTransferSizeInBits(transferType);
 
     Value srcPtr =
         getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
@@ -2391,7 +2391,7 @@ struct ClusterLoadAsyncToLDSOpLowering
     auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     Type transferType = op.getTransferType();
-    int64_t transferBits = getTransferSizeInBits(transferType);
+    unsigned transferBits = getTransferSizeInBits(transferType);
 
     Value srcPtr =
         getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 2f2fffcbcd871..ff810f6ad4093 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -147,15 +147,6 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
   return false;
 }
 
-static bool hasExplicitGlobalMemorySpace(Attribute memorySpace) {
-  if (auto intMemorySpace = dyn_cast_if_present<IntegerAttr>(memorySpace))
-    return intMemorySpace.getInt() == 1;
-  if (auto gpuMemorySpace =
-          dyn_cast_if_present<gpu::AddressSpaceAttr>(memorySpace))
-    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
-  return false;
-}
-
 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
   if (!memorySpace)
     return false;
@@ -1025,49 +1016,55 @@ void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GlobalLoadAsyncToLDSOp
 //===----------------------------------------------------------------------===//
 
-static std::optional<int64_t> getTransferSizeInBits(Type transferType) {
+static std::optional<unsigned> getTransferSizeInBits(Type transferType) {
   Type elementType = transferType;
   int64_t numElements = 1;
   if (auto vectorType = dyn_cast<VectorType>(transferType)) {
     elementType = vectorType.getElementType();
     numElements = vectorType.getNumElements();
   }
 
   if (!elementType.isIntOrFloat())
     return std::nullopt;
 
-  return numElements * elementType.getIntOrFloatBitWidth();
+  // Multiply in 64 bits and saturate when narrowing to `unsigned`: a truncated
+  // product for a very long vector could otherwise wrap onto one of the
+  // accepted transfer sizes and slip through verification.
+  int64_t bits = numElements * elementType.getIntOrFloatBitWidth();
+  return static_cast<unsigned>(
+      std::min<int64_t>(bits, std::numeric_limits<unsigned>::max()));
 }
 
 static LogicalResult
 verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
                                MemRefType dstType, OperandRange srcIndices,
-                               OperandRange dstIndices, Type transferType,
-                               bool requireExplicitGlobal) {
+                               OperandRange dstIndices, Type transferType) {
   if (srcType.getElementType() != dstType.getElementType())
     return op->emitOpError("source and destination element types must match");
 
   if (srcType.getRank() != static_cast<int64_t>(srcIndices.size()))
     return op->emitOpError("source index count must match source memref rank");
 
   if (dstType.getRank() != static_cast<int64_t>(dstIndices.size()))
     return op->emitOpError(
         "destination index count must match destination memref rank");
 
-  std::optional<int64_t> transferSize = getTransferSizeInBits(transferType);
+  // Scalable vectors have no fixed bit width, so they can never match one of
+  // the fixed-size hardware transfers.
+  auto transferVectorType = dyn_cast<VectorType>(transferType);
+  if (transferVectorType && transferVectorType.isScalable())
+    return op->emitOpError("transfer type must not be a scalable vector");
+
+  std::optional<unsigned> transferSize = getTransferSizeInBits(transferType);
   if (!transferSize)
     return op->emitOpError(
         "transfer type must be an integer, float, or vector of integers or "
         "floats");
 
-  if (!llvm::is_contained({8, 32, 64, 128}, *transferSize))
+  if (!llvm::is_contained({8u, 32u, 64u, 128u}, transferSize.value()))
     return op->emitOpError("transfer type size must be 8, 32, 64, or 128 bits");
 
-  bool hasValidGlobalMemorySpace =
-      requireExplicitGlobal
-          ? hasExplicitGlobalMemorySpace(srcType.getMemorySpace())
-          : hasGlobalMemorySpace(srcType.getMemorySpace());
-  if (!hasValidGlobalMemorySpace)
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
     return op->emitOpError("source memory address space must be global");
 
   if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
@@ -1078,12 +1064,11 @@ verifyGlobalLoadAsyncToLDSLike(Operation *op, MemRefType srcType,
 }
 
 LogicalResult GlobalLoadAsyncToLDSOp::verify() {
-  MemRefType srcType = cast<MemRefType>(getSrc().getType());
-  MemRefType dstType = cast<MemRefType>(getDst().getType());
+  auto srcType = cast<MemRefType>(getSrc().getType());
+  auto dstType = cast<MemRefType>(getDst().getType());
   return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
                                         getSrcIndices(), getDstIndices(),
-                                        getTransferType(),
-                                        /*requireExplicitGlobal=*/false);
+                                        getTransferType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1091,12 +1076,11 @@ LogicalResult GlobalLoadAsyncToLDSOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ClusterLoadAsyncToLDSOp::verify() {
-  MemRefType srcType = cast<MemRefType>(getSrc().getType());
-  MemRefType dstType = cast<MemRefType>(getDst().getType());
+  auto srcType = cast<MemRefType>(getSrc().getType());
+  auto dstType = cast<MemRefType>(getDst().getType());
   return verifyGlobalLoadAsyncToLDSLike(*this, srcType, dstType,
                                         getSrcIndices(), getDstIndices(),
-                                        getTransferType(),
-                                        /*requireExplicitGlobal=*/true);
+                                        getTransferType());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index cd21be8b8b5e9..1d00edfaedac0 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -415,28 +415,6 @@ func.func @cluster_load_async_to_lds_src_not_global(%idx1 : index, %mask : i32,
 
 // -----
 
-func.func @cluster_load_async_to_lds_src_default_memory_space(
-    %idx1 : index, %mask : i32, %mem1 : memref<32xf32>,
-    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
-  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
-  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
-    : f32, memref<32xf32>, memref<32xf32, #gpu.address_space<workgroup>>
-  func.return
-}
-
-// -----
-
-func.func @cluster_load_async_to_lds_src_memory_space_zero(
-    %idx1 : index, %mask : i32, %mem1 : memref<32xf32, 0>,
-    %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {
-  // expected-error at +1 {{'amdgpu.cluster_load_async_to_lds' op source memory address space must be global}}
-  amdgpu.cluster_load_async_to_lds %mem1[%idx1], %mem2[%idx1], %mask
-    : f32, memref<32xf32, 0>, memref<32xf32, #gpu.address_space<workgroup>>
-  func.return
-}
-
-// -----
-
 func.func @cluster_load_async_to_lds_src_index_count(%idx1 : index,
     %mask : i32, %mem1 : memref<32x32xf32, #gpu.address_space<global>>,
     %mem2 : memref<32xf32, #gpu.address_space<workgroup>>) {



More information about the Mlir-commits mailing list