[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)

Alan Li llvmlistbot at llvm.org
Wed Apr 2 17:59:43 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/133498

>From ba384661402929671a0de82ffd813548af38f3de Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 14:50:45 -0400
Subject: [PATCH 01/11] make it compilable.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..53323be7d6c5c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
+// MI300: limit to only 4 bytes per elements.
+def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
+                                VectorOfLengthAndType<[2], [F16, BF16, I16]>,
+                                VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
+                               ]>;
+
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
@@ -765,4 +771,30 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_GlobalLoadLDSOp :
+    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Variadic<I32>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<I32>:$dstIndices
+                   )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
+
+    The 
+
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU

>From 8e46990cc375af4a2dbf908da1cb79a3a93c138c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 15:38:00 -0400
Subject: [PATCH 02/11] Adding a verifier.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 30 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 53323be7d6c5c..b7cfed1f18c6c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -792,7 +792,7 @@ def AMDGPU_GlobalLoadLDSOp :
 
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..64003e9a2dca6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Metadata.h"
 
 #include <limits>
 #include <optional>
@@ -459,6 +461,34 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+LogicalResult GlobalLoadLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
+      !memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "source and destination types must have static shape  and contiguous");
+
+  // Check $src and $dst element types are the same.
+  if (srcType.getElementType() != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // Check $src and $dst memory spaces.
+  auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
+  auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
+  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
+    return emitOpError("source memory address space must be Global");
+  if (dstAddrSpace.getInt() != 3)
+    return emitOpError("destination memory address space must be Workgroup");
+
+  // Check chunk size compatible with element type.
+  // TODO
+  auto dstChunkSize = dstType.getShape();
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

>From c072e660b1147a4f0b439f41735edc3df2f07bc5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 26 Mar 2025 22:30:13 -0400
Subject: [PATCH 03/11] checkpoint: add GlobalLoadLDSOp lowering to ROCDL

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 76 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b7cfed1f18c6c..e60eaa6d11897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
+
 def AMDGPU_GlobalLoadLDSOp :
     AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
     Arguments<(ins
-                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
                    Variadic<I32>:$srcIndices,
-                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
                    Variadic<I32>:$dstIndices
                    )>,
     Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
 
     The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
 
-    The 
-
+    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+    will write to.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..83e18a3d87b47 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+    if (elemSizeInBits % 8 != 0)
+      return op.emitOpError("element size must be a multiple of 8");
+    auto loadWidth = elemSizeInBits / 8;
+
+    // TODO: add chipset support check
+    if (chipset.majorVersion >= 12)
+      return op.emitOpError("TODO");
+
+    // TODO: fold this into chipset check.
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+      return op.emitOpError("unsupported element size");
+
+    Value src = adaptor.getSrc();
+    Value dst = adaptor.getDst();
+    Value memrefSrc = op.getSrc();
+    Value memrefDst = op.getDst();
+
+    // Collapse src memref with indices:
+    auto flattenIndex = [&](Value memref, MemRefType memrefType,
+                            ValueRange indices) -> std::optional<Value> {
+      MemRefDescriptor memRefDescriptor(memref);
+      int64_t offset = 0;
+      SmallVector<int64_t, 5> strides;
+      if (failed(memrefType.getStridesAndOffset(strides, offset)))
+        return {};
+      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
+                               strides);
+    };
+
+    // Source
+    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                  op.getSrcIndices());
+    if (!optSrcIdx)
+      return op.emitOpError("failed to flatten source memref indices");
+    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                  op.getDstIndices());
+    if (!optDstIdx)
+      return op.emitOpError("failed to flatten destination memref indices");
+
+    Type srcPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Value srcPtr = rewriter.create<LLVM::GEPOp>(
+        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
+    
+    Value dstPtr = rewriter.create<LLVM::GEPOp>(
+        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0));
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                      chipset);
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GlobalLoadLDSOpLowering>(converter, chipset);
 }

>From 564ebc8cd29ffe7437e1b21aacb626874e3a2453 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 10:15:56 -0400
Subject: [PATCH 04/11] Make it work: index from the memref descriptor's buffer pointer

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 47 +++++++++++--------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 83e18a3d87b47..4169eb7825c95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+struct GlobalLoadLDSOpLowering
+    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
   GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
 
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
     if (elemSizeInBits % 8 != 0)
       return op.emitOpError("element size must be a multiple of 8");
+
+    // TODO: instead of only transfering one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
     // TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     Value memrefSrc = op.getSrc();
     Value memrefDst = op.getDst();
 
-    // Collapse src memref with indices:
-    auto flattenIndex = [&](Value memref, MemRefType memrefType,
-                            ValueRange indices) -> std::optional<Value> {
+    // Collapse src memref with indices, returns the base pointer and linearized
+    // index.
+    auto flattenIndex =
+        [&](Value memref, MemRefType memrefType,
+            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
       int64_t offset = 0;
       SmallVector<int64_t, 5> strides;
       if (failed(memrefType.getStridesAndOffset(strides, offset)))
         return {};
-      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
-                               strides);
+      return std::make_pair(
+          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                     memrefType),
+          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
     // Source
-    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                  op.getSrcIndices());
-    if (!optSrcIdx)
+    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                     op.getSrcIndices());
+    if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                  op.getDstIndices());
-    if (!optDstIdx)
+    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                     op.getDstIndices());
+    if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
-    Type srcPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    Type dstPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
     Value srcPtr = rewriter.create<LLVM::GEPOp>(
-        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
-    
+        loc, srcPtrType, elemType, optSrcBuffer->first,
+        ArrayRef<Value>({optSrcBuffer->second}));
+
     Value dstPtr = rewriter.create<LLVM::GEPOp>(
-        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+        loc, dstPtrType, elemType, optDstBuffer->first,
+        ArrayRef<Value>({optDstBuffer->second}));
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

>From 92a1ef9d5fa73dbb5c486a1e7dc29290d48b2067 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 13:04:27 -0400
Subject: [PATCH 05/11] update AMDGPU description, add chipset check and lowering test.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++++-----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 30 ++++++++-----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  5 ----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 23 ++++++++++++++
 4 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e60eaa6d11897..606cf69537ca1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
     Results<(outs)> {
   let summary = "MLIR wrapper for CDNA mfma instructions";
   let description = [{
-    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
-    for various `mfma` instructions in the CDNA architecture, which perform
-    multiple outer products in order to allow fast matrix multiplication.
-
-    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
-
-    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+      number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
+    
     The `$dst`, along with its indices, points to the memory location the subgroup of this thread
     will write to.
+  
+    Note: only enabled for gfx942 and later.
   }];
   let assemblyFormat = [{
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4169eb7825c95..ffb347ade326c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
     // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
-    // TODO: add chipset support check
-    if (chipset.majorVersion >= 12)
-      return op.emitOpError("TODO");
+    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+    if (chipset < GlobalLoadEnabled)
+      return op.emitOpError("chipset not supported");
 
-    // TODO: fold this into chipset check.
     // Currently only 1, 2, and 4 byte loads are supported.
     if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
-      return op.emitOpError("unsupported element size");
+      return op.emitOpError("chipset unsupported element size");
 
-    Value src = adaptor.getSrc();
-    Value dst = adaptor.getDst();
-    Value memrefSrc = op.getSrc();
-    Value memrefDst = op.getDst();
-
-    // Collapse src memref with indices, returns the base pointer and linearized
-    // index.
-    auto flattenIndex =
+    // Return pair of {base pointer, linearized index}.
+    auto getBasePtrAndLinearizedIndex =
         [&](Value memref, MemRefType memrefType,
             ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
           getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
-    // Source
-    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                     op.getSrcIndices());
+    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+        op.getSrcIndices());
     if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                     op.getDstIndices());
+    auto optDstBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+        op.getDstIndices());
     if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 64003e9a2dca6..efec76c91ea23 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -481,11 +481,6 @@ LogicalResult GlobalLoadLDSOp::verify() {
     return emitOpError("source memory address space must be Global");
   if (dstAddrSpace.getInt() != 3)
     return emitOpError("destination memory address space must be Workgroup");
-
-  // Check chunk size compatible with element type.
-  // TODO
-  auto dstChunkSize = dstType.getShape();
-
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..65515dbe30b29 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,6 +9,7 @@
 // test pass doesn't set up the GPU address space conversions.
 
 #gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
 
 // CHECK-LABEL: func @fat_raw_buffer_cast
 func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : i32
+  %c12 = arith.constant 12 : i32
+  %c32 = arith.constant 32 : i32
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+  // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
+  // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
+  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+  // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}

>From a17e854c5f2f72e486f7fc53a677afa1e708c700 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 11:21:10 -0400
Subject: [PATCH 06/11] Address review comments: rename to gather_to_lds, relax verifier.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 16 +++---------
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 15 ++++++-----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 25 ++++++++++---------
 3 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 606cf69537ca1..70d5f5de416d1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,12 +668,6 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
-// MI300: limit to only 4 bytes per elements.
-def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
-                                VectorOfLengthAndType<[2], [F16, BF16, I16]>,
-                                VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
-                               ]>;
-
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
@@ -771,14 +765,12 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
-def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
-
-def AMDGPU_GlobalLoadLDSOp :
-    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+def AMDGPU_GatherToLDSOp :
+    AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
     Arguments<(ins
-                   Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
+                   Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
                    Variadic<I32>:$srcIndices,
-                   Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
                    Variadic<I32>:$dstIndices
                    )>,
     Results<(outs)> {
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ffb347ade326c..5ade785ea9b93 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,15 +903,15 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GlobalLoadLDSOpLowering
-    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
-  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+struct GatherToLDSOpLowering
+    : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
 
   Chipset chipset;
 
   LogicalResult
-  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
@@ -925,8 +925,7 @@ struct GlobalLoadLDSOpLowering
     // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
-    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
-    if (chipset < GlobalLoadEnabled)
+    if (chipset < kGfx942)
       return op.emitOpError("chipset not supported");
 
     // Currently only 1, 2, and 4 byte loads are supported.
@@ -1362,5 +1361,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GlobalLoadLDSOpLowering>(converter, chipset);
+           GatherToLDSOpLowering>(converter, chipset);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index efec76c91ea23..27e0cfb8044b3 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/IR/Metadata.h"
 
 #include <limits>
 #include <optional>
@@ -461,26 +460,28 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
-LogicalResult GlobalLoadLDSOp::verify() {
+LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
 
-  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
-      !memref::isStaticShapeAndContiguousRowMajor(dstType))
+  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
     return emitOpError(
-        "source and destination types must have static shape  and contiguous");
+        "destination types must have static shape  and contiguous");
 
+  auto elemType = srcType.getElementType();
   // Check $src and $dst element types are the same.
-  if (srcType.getElementType() != dstType.getElementType())
+  if (elemType != dstType.getElementType())
     return emitOpError("source and destination element types must match");
 
-  // Check $src and $dst memory spaces.
-  auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
-  auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
-  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
-    return emitOpError("source memory address space must be Global");
-  if (dstAddrSpace.getInt() != 3)
+  // Element type sizes should be 1, 2, or 4 bytes.
+  auto elemSize = elemType.getIntOrFloatBitWidth();
+  if (elemSize != 8 && elemSize != 16 && elemSize != 32)
+    return emitOpError("source and destination element types must be 8, 16, "
+                       "or 32 bits");
+
+  if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
     return emitOpError("destination memory address space must be Workgroup");
+
   return success();
 }
 

>From 4ed20068c55450581929686679197ffd999c68ab Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:12:42 -0400
Subject: [PATCH 07/11] update lowering: use index operands and a transfer type attribute

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 10 ++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 73 ++++++++-----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 49 +++++++++----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 12 +--
 4 files changed, 78 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 70d5f5de416d1..350b184dae6a2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -769,9 +769,10 @@ def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
     Arguments<(ins
                    Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
-                   Variadic<I32>:$srcIndices,
+                   Variadic<Index>:$srcIndices,
                    Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
-                   Variadic<I32>:$dstIndices
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType
                    )>,
     Results<(outs)> {
   let summary = "MLIR wrapper for CDNA mfma instructions";
@@ -784,7 +785,10 @@ def AMDGPU_GatherToLDSOp :
     * `$dst`: LDS memory memref to write to.
     * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
       number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
-    
+    * `$transferType`: type of the data to be transferred by each thread. This is used to determine
+      the size of the data to be transferred and the number of threads in the subgroup.
+      The transfer type must be a scalar type or a vector type with a single element type.
+
     The `$dst`, along with its indices, points to the memory location the subgroup of this thread
     will write to.
   
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5ade785ea9b93..a644c24a8d6bc 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -913,60 +913,49 @@ struct GatherToLDSOpLowering
   LogicalResult
   matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx942)
+      return op.emitOpError("chipset not supported");
+
     Location loc = op.getLoc();
 
-    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
-    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
-    if (elemSizeInBits % 8 != 0)
-      return op.emitOpError("element size must be a multiple of 8");
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
     // `global_load_lds` instructions.
-    auto loadWidth = elemSizeInBits / 8;
-
-    if (chipset < kGfx942)
-      return op.emitOpError("chipset not supported");
+    size_t loadWidth;
+    Type transferType = op.getTransferType();
+    if (auto transferVectorType = dyn_cast<VectorType>(transferType))
+      loadWidth = transferVectorType.getNumElements() *
+                  transferVectorType.getElementTypeBitWidth() / 8;
+    else
+      loadWidth = transferType.getIntOrFloatBitWidth() / 8;
 
     // Currently only 1, 2, and 4 byte loads are supported.
-    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    // Return pair of {base pointer, linearized index}.
-    auto getBasePtrAndLinearizedIndex =
-        [&](Value memref, MemRefType memrefType,
-            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
-      MemRefDescriptor memRefDescriptor(memref);
-      int64_t offset = 0;
-      SmallVector<int64_t, 5> strides;
-      if (failed(memrefType.getStridesAndOffset(strides, offset)))
-        return {};
-      return std::make_pair(
-          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
-                                     memrefType),
-          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
+    auto convertIndices =
+        [&](ValueRange indices) -> SmallVector<Value, 4> {
+      SmallVector<Value, 4> convertedIndices;
+      
+      for (Value index : indices) {
+        Type convertedType = getTypeConverter()->convertType(index.getType());
+        auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
+            loc, convertedType,
+            rewriter.getIntegerAttr(convertedType, 0));
+        convertedIndices.push_back(convertedIndex);
+      }
+      return convertedIndices;
     };
 
-    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
-        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
-        op.getSrcIndices());
-    if (!optSrcBuffer)
-      return op.emitOpError("failed to flatten source memref indices");
-    auto optDstBuffer = getBasePtrAndLinearizedIndex(
-        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
-        op.getDstIndices());
-    if (!optDstBuffer)
-      return op.emitOpError("failed to flatten destination memref indices");
-
-    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
-    Value srcPtr = rewriter.create<LLVM::GEPOp>(
-        loc, srcPtrType, elemType, optSrcBuffer->first,
-        ArrayRef<Value>({optSrcBuffer->second}));
-
-    Value dstPtr = rewriter.create<LLVM::GEPOp>(
-        loc, dstPtrType, elemType, optDstBuffer->first,
-        ArrayRef<Value>({optDstBuffer->second}));
+    Value srcPtr =
+        getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                             convertIndices(op.getSrcIndices()), rewriter);
+    Value dstPtr =
+        getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                             convertIndices(op.getDstIndices()), rewriter);
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 27e0cfb8044b3..605f0a92224e9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
 
 #include <limits>
 #include <optional>
@@ -113,21 +114,30 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return true;
+  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  else if (auto gpuMemorySpace =
+          llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
-  Attribute memorySpace = bufferType.getMemorySpace();
-  bool isGlobal = false;
-  if (!memorySpace)
-    isGlobal = true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
 
   if (!isGlobal)
     return op.emitOpError(
@@ -473,13 +483,22 @@ LogicalResult GatherToLDSOp::verify() {
   if (elemType != dstType.getElementType())
     return emitOpError("source and destination element types must match");
 
-  // Element type sizes should be 1, 2, or 4 bytes.
-  auto elemSize = elemType.getIntOrFloatBitWidth();
-  if (elemSize != 8 && elemSize != 16 && elemSize != 32)
-    return emitOpError("source and destination element types must be 8, 16, "
-                       "or 32 bits");
+  // Each thread transfers 1, 2, or 4 bytes. Validate the element type first:
+  // getIntOrFloatBitWidth() asserts on other types such as index.
+  Type transferType = getTransferType();
+  Type transferElemType = getElementTypeOrSelf(transferType);
+  if (!transferElemType.isIntOrFloat())
+    return emitOpError("transfer element type must be an integer or float");
+  int64_t transferSize = transferElemType.getIntOrFloatBitWidth();
+  if (auto vectorTransfer = dyn_cast<VectorType>(transferType))
+    transferSize *= vectorTransfer.getNumElements();
+  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("Transfering type size must be 8, 16, or 32 bits");
+  Attribute dstMemorySpace = dstType.getMemorySpace();
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
 
-  if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
+  if (!dstMemorySpace || !hasWorkgroupMemorySpace(dstMemorySpace))
     return emitOpError("destination memory address space must be Workgroup");
 
   return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 65515dbe30b29..5dad6b75f9ec8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -466,21 +466,21 @@ func.func @sched_barrier() {
 // CHECK-LABEL: func @global_load_to_rocdl_f32
 // CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
 func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
-  %c0 = arith.constant 0 : i32
-  %c12 = arith.constant 12 : i32
-  %c32 = arith.constant 32 : i32
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
   // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
   // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
   // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
   // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
-  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
   // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
   // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
   func.return
 }

>From 81bb8bc2aa6b75aabe6a0d50436cabca90287691 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:37:42 -0400
Subject: [PATCH 08/11] Move gather_to_lds lowering test into load_lds.mlir

---
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 23 ----------------
 .../Conversion/AMDGPUToROCDL/load_lds.mlir    | 26 +++++++++++++++++++
 2 files changed, 26 insertions(+), 23 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 5dad6b75f9ec8..8871b2ce0eadb 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,7 +9,6 @@
 // test pass doesn't set up the GPU address space conversions.
 
 #gpu_global_addrspace = 1
-#gpu_lds_addrspace = 3
 
 // CHECK-LABEL: func @fat_raw_buffer_cast
 func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -462,25 +461,3 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
-
-// CHECK-LABEL: func @global_load_to_rocdl_f32
-// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
-func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
-  %c0 = arith.constant 0 : index
-  %c12 = arith.constant 12 : index
-  %c32 = arith.constant 32 : index
-  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
-  // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
-  // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
-  // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
-  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
-  // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
-  // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
-  func.return
-}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..507928eaf0d78
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}

>From 73629f4f09c581e027943804270c582621bc7192 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 16:09:58 -0400
Subject: [PATCH 09/11] Apply clang-format to GatherToLDS changes

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++-------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        |  2 +-
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a644c24a8d6bc..5e8cbf262f0eb 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,8 +903,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GatherToLDSOpLowering
-    : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
 
@@ -936,15 +935,13 @@ struct GatherToLDSOpLowering
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    auto convertIndices =
-        [&](ValueRange indices) -> SmallVector<Value, 4> {
+    auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
       SmallVector<Value, 4> convertedIndices;
-      
+
       for (Value index : indices) {
         Type convertedType = getTypeConverter()->convertType(index.getType());
         auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
-            loc, convertedType,
-            rewriter.getIntegerAttr(convertedType, 0));
+            loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
         convertedIndices.push_back(convertedIndex);
       }
       return convertedIndices;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 605f0a92224e9..00d67f58aee1c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -120,7 +120,7 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
   else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
   else if (auto gpuMemorySpace =
-          llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
   return false;
 }

>From b48370161416b7ddce8734664162d282ca4c2bcb Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:52:19 -0400
Subject: [PATCH 10/11] Use adaptor indices for GatherToLDS lowering and expand tests

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  22 +--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |   7 +-
 .../Conversion/AMDGPUToROCDL/load_lds.mlir    | 137 +++++++++++++++++-
 3 files changed, 139 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5e8cbf262f0eb..7efe70fd16a04 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,24 +935,10 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
-      SmallVector<Value, 4> convertedIndices;
-
-      for (Value index : indices) {
-        Type convertedType = getTypeConverter()->convertType(index.getType());
-        auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
-            loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
-        convertedIndices.push_back(convertedIndex);
-      }
-      return convertedIndices;
-    };
-
-    Value srcPtr =
-        getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
-                             convertIndices(op.getSrcIndices()), rewriter);
-    Value dstPtr =
-        getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
-                             convertIndices(op.getDstIndices()), rewriter);
+    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                                        adaptor.getSrcIndices(), rewriter);
+    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                                        adaptor.getDstIndices(), rewriter);
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 00d67f58aee1c..6309dac761e4f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -117,10 +117,9 @@ LogicalResult FatRawBufferCastOp::verify() {
 static bool hasGlobalMemorySpace(Attribute memorySpace) {
   if (!memorySpace)
     return true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
   return false;
 }
@@ -128,6 +127,8 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
   if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
   return false;
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 507928eaf0d78..c887f6504172a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -10,17 +10,142 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
-  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
-  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
+    : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]], %[[C0]], %[[C0_2]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
+    : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
   // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
-  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
-  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+
+  // CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
   // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
+    : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
   func.return
 }
+
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+  // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+  // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<4x64xi32, 3>
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[DST_ROW:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[DST_ROW]], %{{.*}} : i64
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %{{.*}}, %{{.*}}
+  %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
+    : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  func.return
+}
\ No newline at end of file

>From 1a40d6c20698052a569c8f1a49425ed6acb3834c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:59:25 -0400
Subject: [PATCH 11/11] Print transferType in gather_to_lds assembly format

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td    |  2 +-
 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 350b184dae6a2..52753e2faf901 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -795,7 +795,7 @@ def AMDGPU_GatherToLDSOp :
     Note: only enabled for gfx942 and later.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index c887f6504172a..160e5b203ed1f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -39,8 +39,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
-    : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
   func.return
 }
 
@@ -80,8 +80,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
-    : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
   func.return
 }
 
@@ -121,8 +121,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
-    : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
   func.return
 }
 
@@ -145,7 +145,7 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
   %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
-    : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+    : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
   func.return
 }
\ No newline at end of file



More information about the Mlir-commits mailing list