[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)

Alan Li llvmlistbot at llvm.org
Mon Apr 7 10:59:28 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/133498

>From ba384661402929671a0de82ffd813548af38f3de Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 14:50:45 -0400
Subject: [PATCH 01/12] Make it compilable.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..53323be7d6c5c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
+// MI300: limit to only 4 bytes per elements.
+def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
+                                VectorOfLengthAndType<[2], [F16, BF16, I16]>,
+                                VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
+                               ]>;
+
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
@@ -765,4 +771,30 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_GlobalLoadLDSOp :
+    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Variadic<I32>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<I32>:$dstIndices
+                   )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
+
+    The 
+
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU

>From 8e46990cc375af4a2dbf908da1cb79a3a93c138c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 15:38:00 -0400
Subject: [PATCH 02/12] Adding a verifier.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 30 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 53323be7d6c5c..b7cfed1f18c6c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -792,7 +792,7 @@ def AMDGPU_GlobalLoadLDSOp :
 
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..64003e9a2dca6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Metadata.h"
 
 #include <limits>
 #include <optional>
@@ -459,6 +461,34 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+LogicalResult GlobalLoadLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
+      !memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "source and destination types must have static shape  and contiguous");
+
+  // Check $src and $dst element types are the same.
+  if (srcType.getElementType() != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // Check $src and $dst memory spaces.
+  auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
+  auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
+  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
+    return emitOpError("source memory address space must be Global");
+  if (dstAddrSpace.getInt() != 3)
+    return emitOpError("destination memory address space must be Workgroup");
+
+  // Check chunk size compatible with element type.
+  // TODO
+  auto dstChunkSize = dstType.getShape();
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

>From c072e660b1147a4f0b439f41735edc3df2f07bc5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 26 Mar 2025 22:30:13 -0400
Subject: [PATCH 03/12] checkpoint

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 76 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b7cfed1f18c6c..e60eaa6d11897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
+
 def AMDGPU_GlobalLoadLDSOp :
     AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
     Arguments<(ins
-                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
                    Variadic<I32>:$srcIndices,
-                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
                    Variadic<I32>:$dstIndices
                    )>,
     Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
 
     The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
 
-    The 
-
+    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+    will write to.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..83e18a3d87b47 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+    if (elemSizeInBits % 8 != 0)
+      return op.emitOpError("element size must be a multiple of 8");
+    auto loadWidth = elemSizeInBits / 8;
+
+    // TODO: add chipset support check
+    if (chipset.majorVersion >= 12)
+      return op.emitOpError("TODO");
+
+    // TODO: fold this into chipset check.
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+      return op.emitOpError("unsupported element size");
+
+    Value src = adaptor.getSrc();
+    Value dst = adaptor.getDst();
+    Value memrefSrc = op.getSrc();
+    Value memrefDst = op.getDst();
+
+    // Collapse src memref with indices:
+    auto flattenIndex = [&](Value memref, MemRefType memrefType,
+                            ValueRange indices) -> std::optional<Value> {
+      MemRefDescriptor memRefDescriptor(memref);
+      int64_t offset = 0;
+      SmallVector<int64_t, 5> strides;
+      if (failed(memrefType.getStridesAndOffset(strides, offset)))
+        return {};
+      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
+                               strides);
+    };
+
+    // Source
+    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                  op.getSrcIndices());
+    if (!optSrcIdx)
+      return op.emitOpError("failed to flatten source memref indices");
+    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                  op.getDstIndices());
+    if (!optDstIdx)
+      return op.emitOpError("failed to flatten destination memref indices");
+
+    Type srcPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Value srcPtr = rewriter.create<LLVM::GEPOp>(
+        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
+    
+    Value dstPtr = rewriter.create<LLVM::GEPOp>(
+        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0));
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                      chipset);
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GlobalLoadLDSOpLowering>(converter, chipset);
 }

>From 564ebc8cd29ffe7437e1b21aacb626874e3a2453 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 10:15:56 -0400
Subject: [PATCH 04/12] Make it work

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 47 +++++++++++--------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 83e18a3d87b47..4169eb7825c95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+struct GlobalLoadLDSOpLowering
+    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
   GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
 
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
     if (elemSizeInBits % 8 != 0)
       return op.emitOpError("element size must be a multiple of 8");
+
+    // TODO: instead of only transfering one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
     // TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     Value memrefSrc = op.getSrc();
     Value memrefDst = op.getDst();
 
-    // Collapse src memref with indices:
-    auto flattenIndex = [&](Value memref, MemRefType memrefType,
-                            ValueRange indices) -> std::optional<Value> {
+    // Collapse src memref with indices, returns the base pointer and linearized
+    // index.
+    auto flattenIndex =
+        [&](Value memref, MemRefType memrefType,
+            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
       int64_t offset = 0;
       SmallVector<int64_t, 5> strides;
       if (failed(memrefType.getStridesAndOffset(strides, offset)))
         return {};
-      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
-                               strides);
+      return std::make_pair(
+          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                     memrefType),
+          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
     // Source
-    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                  op.getSrcIndices());
-    if (!optSrcIdx)
+    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                     op.getSrcIndices());
+    if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                  op.getDstIndices());
-    if (!optDstIdx)
+    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                     op.getDstIndices());
+    if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
-    Type srcPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    Type dstPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
     Value srcPtr = rewriter.create<LLVM::GEPOp>(
-        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
-    
+        loc, srcPtrType, elemType, optSrcBuffer->first,
+        ArrayRef<Value>({optSrcBuffer->second}));
+
     Value dstPtr = rewriter.create<LLVM::GEPOp>(
-        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+        loc, dstPtrType, elemType, optDstBuffer->first,
+        ArrayRef<Value>({optDstBuffer->second}));
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

>From 92a1ef9d5fa73dbb5c486a1e7dc29290d48b2067 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 13:04:27 -0400
Subject: [PATCH 05/12] update AMDGPU description.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++++-----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 30 ++++++++-----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  5 ----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 23 ++++++++++++++
 4 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e60eaa6d11897..606cf69537ca1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
     Results<(outs)> {
   let summary = "MLIR wrapper for CDNA mfma instructions";
   let description = [{
-    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
-    for various `mfma` instructions in the CDNA architecture, which perform
-    multiple outer products in order to allow fast matrix multiplication.
-
-    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
-
-    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+      number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
+    
     The `$dst`, along with its indices, points to the memory location the subgroup of this thread
     will write to.
+  
+    Note: only enabled for gfx942 and later.
   }];
   let assemblyFormat = [{
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4169eb7825c95..ffb347ade326c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
     // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
-    // TODO: add chipset support check
-    if (chipset.majorVersion >= 12)
-      return op.emitOpError("TODO");
+    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+    if (chipset < GlobalLoadEnabled)
+      return op.emitOpError("chipset not supported");
 
-    // TODO: fold this into chipset check.
     // Currently only 1, 2, and 4 byte loads are supported.
     if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
-      return op.emitOpError("unsupported element size");
+      return op.emitOpError("chipset unsupported element size");
 
-    Value src = adaptor.getSrc();
-    Value dst = adaptor.getDst();
-    Value memrefSrc = op.getSrc();
-    Value memrefDst = op.getDst();
-
-    // Collapse src memref with indices, returns the base pointer and linearized
-    // index.
-    auto flattenIndex =
+    // Return pair of {base pointer, linearized index}.
+    auto getBasePtrAndLinearizedIndex =
         [&](Value memref, MemRefType memrefType,
             ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
           getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
-    // Source
-    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                     op.getSrcIndices());
+    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+        op.getSrcIndices());
     if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                     op.getDstIndices());
+    auto optDstBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+        op.getDstIndices());
     if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 64003e9a2dca6..efec76c91ea23 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -481,11 +481,6 @@ LogicalResult GlobalLoadLDSOp::verify() {
     return emitOpError("source memory address space must be Global");
   if (dstAddrSpace.getInt() != 3)
     return emitOpError("destination memory address space must be Workgroup");
-
-  // Check chunk size compatible with element type.
-  // TODO
-  auto dstChunkSize = dstType.getShape();
-
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..65515dbe30b29 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,6 +9,7 @@
 // test pass doesn't set up the GPU address space conversions.
 
 #gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
 
 // CHECK-LABEL: func @fat_raw_buffer_cast
 func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : i32
+  %c12 = arith.constant 12 : i32
+  %c32 = arith.constant 32 : i32
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+  // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
+  // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
+  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+  // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}

>From a17e854c5f2f72e486f7fc53a677afa1e708c700 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 11:21:10 -0400
Subject: [PATCH 06/12] Address comments.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 16 +++---------
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 15 ++++++-----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 25 ++++++++++---------
 3 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 606cf69537ca1..70d5f5de416d1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,12 +668,6 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
-// MI300: limit to only 4 bytes per elements.
-def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
-                                VectorOfLengthAndType<[2], [F16, BF16, I16]>,
-                                VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
-                               ]>;
-
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
@@ -771,14 +765,12 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
-def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
-
-def AMDGPU_GlobalLoadLDSOp :
-    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+def AMDGPU_GatherToLDSOp :
+    AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
     Arguments<(ins
-                   Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
+                   Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
                    Variadic<I32>:$srcIndices,
-                   Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
                    Variadic<I32>:$dstIndices
                    )>,
     Results<(outs)> {
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ffb347ade326c..5ade785ea9b93 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,15 +903,15 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GlobalLoadLDSOpLowering
-    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
-  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+struct GatherToLDSOpLowering
+    : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
 
   Chipset chipset;
 
   LogicalResult
-  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
@@ -925,8 +925,7 @@ struct GlobalLoadLDSOpLowering
     // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
-    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
-    if (chipset < GlobalLoadEnabled)
+    if (chipset < kGfx942)
       return op.emitOpError("chipset not supported");
 
     // Currently only 1, 2, and 4 byte loads are supported.
@@ -1362,5 +1361,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GlobalLoadLDSOpLowering>(converter, chipset);
+           GatherToLDSOpLowering>(converter, chipset);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index efec76c91ea23..27e0cfb8044b3 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/IR/Metadata.h"
 
 #include <limits>
 #include <optional>
@@ -461,26 +460,28 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
-LogicalResult GlobalLoadLDSOp::verify() {
+LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
 
-  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
-      !memref::isStaticShapeAndContiguousRowMajor(dstType))
+  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
     return emitOpError(
-        "source and destination types must have static shape  and contiguous");
+        "destination types must have static shape  and contiguous");
 
+  auto elemType = srcType.getElementType();
   // Check $src and $dst element types are the same.
-  if (srcType.getElementType() != dstType.getElementType())
+  if (elemType != dstType.getElementType())
     return emitOpError("source and destination element types must match");
 
-  // Check $src and $dst memory spaces.
-  auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
-  auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
-  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
-    return emitOpError("source memory address space must be Global");
-  if (dstAddrSpace.getInt() != 3)
+  // Element type sizes should be 1, 2, or 4 bytes.
+  auto elemSize = elemType.getIntOrFloatBitWidth();
+  if (elemSize != 8 && elemSize != 16 && elemSize != 32)
+    return emitOpError("source and destination element types must be 8, 16, "
+                       "or 32 bits");
+
+  if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
     return emitOpError("destination memory address space must be Workgroup");
+
   return success();
 }
 

>From 4ed20068c55450581929686679197ffd999c68ab Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:12:42 -0400
Subject: [PATCH 07/12] update lowering

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 10 ++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 73 ++++++++-----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 49 +++++++++----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 12 +--
 4 files changed, 78 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 70d5f5de416d1..350b184dae6a2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -769,9 +769,10 @@ def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
     Arguments<(ins
                    Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
-                   Variadic<I32>:$srcIndices,
+                   Variadic<Index>:$srcIndices,
                    Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
-                   Variadic<I32>:$dstIndices
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType
                    )>,
     Results<(outs)> {
   let summary = "MLIR wrapper for CDNA mfma instructions";
@@ -784,7 +785,10 @@ def AMDGPU_GatherToLDSOp :
     * `$dst`: LDS memory memref to write to.
     * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
       number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
-    
+    * `$transferType`: type of the data to be transferred by each thread. This is used to determine
+      the size of the data to be transferred and the number of threads in the subgroup.
+      The transfer type must be a scalar type or a vector type with a single element type.
+
     The `$dst`, along with its indices, points to the memory location the subgroup of this thread
     will write to.
   
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5ade785ea9b93..a644c24a8d6bc 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -913,60 +913,49 @@ struct GatherToLDSOpLowering
   LogicalResult
   matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx942)
+      return op.emitOpError("chipset not supported");
+
     Location loc = op.getLoc();
 
-    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
-    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
-    if (elemSizeInBits % 8 != 0)
-      return op.emitOpError("element size must be a multiple of 8");
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
 
     // TODO: instead of only transfering one element per thread, we could
     // augment it to transfer multiple elements per thread by issuing multiple
     // `global_load_lds` instructions.
-    auto loadWidth = elemSizeInBits / 8;
-
-    if (chipset < kGfx942)
-      return op.emitOpError("chipset not supported");
+    size_t loadWidth;
+    Type transferType = op.getTransferType();
+    if (auto transferVectorType = dyn_cast<VectorType>(transferType))
+      loadWidth = transferVectorType.getNumElements() *
+                  transferVectorType.getElementTypeBitWidth() / 8;
+    else
+      loadWidth = transferType.getIntOrFloatBitWidth() / 8;
 
     // Currently only 1, 2, and 4 byte loads are supported.
-    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    // Return pair of {base pointer, linearized index}.
-    auto getBasePtrAndLinearizedIndex =
-        [&](Value memref, MemRefType memrefType,
-            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
-      MemRefDescriptor memRefDescriptor(memref);
-      int64_t offset = 0;
-      SmallVector<int64_t, 5> strides;
-      if (failed(memrefType.getStridesAndOffset(strides, offset)))
-        return {};
-      return std::make_pair(
-          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
-                                     memrefType),
-          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
+    auto convertIndices =
+        [&](ValueRange indices) -> SmallVector<Value, 4> {
+      SmallVector<Value, 4> convertedIndices;
+      
+      for (Value index : indices) {
+        Type convertedType = getTypeConverter()->convertType(index.getType());
+        auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
+            loc, convertedType,
+            rewriter.getIntegerAttr(convertedType, 0));
+        convertedIndices.push_back(convertedIndex);
+      }
+      return convertedIndices;
     };
 
-    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
-        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
-        op.getSrcIndices());
-    if (!optSrcBuffer)
-      return op.emitOpError("failed to flatten source memref indices");
-    auto optDstBuffer = getBasePtrAndLinearizedIndex(
-        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
-        op.getDstIndices());
-    if (!optDstBuffer)
-      return op.emitOpError("failed to flatten destination memref indices");
-
-    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
-    Value srcPtr = rewriter.create<LLVM::GEPOp>(
-        loc, srcPtrType, elemType, optSrcBuffer->first,
-        ArrayRef<Value>({optSrcBuffer->second}));
-
-    Value dstPtr = rewriter.create<LLVM::GEPOp>(
-        loc, dstPtrType, elemType, optDstBuffer->first,
-        ArrayRef<Value>({optDstBuffer->second}));
+    Value srcPtr =
+        getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                             convertIndices(op.getSrcIndices()), rewriter);
+    Value dstPtr =
+        getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                             convertIndices(op.getDstIndices()), rewriter);
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 27e0cfb8044b3..605f0a92224e9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
 
 #include <limits>
 #include <optional>
@@ -113,21 +114,30 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return true;
+  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  else if (auto gpuMemorySpace =
+          llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
-  Attribute memorySpace = bufferType.getMemorySpace();
-  bool isGlobal = false;
-  if (!memorySpace)
-    isGlobal = true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
 
   if (!isGlobal)
     return op.emitOpError(
@@ -473,13 +483,22 @@ LogicalResult GatherToLDSOp::verify() {
   if (elemType != dstType.getElementType())
     return emitOpError("source and destination element types must match");
 
-  // Element type sizes should be 1, 2, or 4 bytes.
-  auto elemSize = elemType.getIntOrFloatBitWidth();
-  if (elemSize != 8 && elemSize != 16 && elemSize != 32)
-    return emitOpError("source and destination element types must be 8, 16, "
-                       "or 32 bits");
+  // Transfer type size must be 1, 2, or 4 bytes (8, 16, or 32 bits).
+  auto transferType = getTransferType();
+  size_t transferSize;
+  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+    transferSize = vectorTransfer.getNumElements() *
+                   vectorTransfer.getElementTypeBitWidth();
+  } else {
+    transferSize = transferType.getIntOrFloatBitWidth();
+  }
+  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("Transfering type size must be 8, 16, or 32 bits");
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
 
-  if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
+  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
     return emitOpError("destination memory address space must be Workgroup");
 
   return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 65515dbe30b29..5dad6b75f9ec8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -466,21 +466,21 @@ func.func @sched_barrier() {
 // CHECK-LABEL: func @global_load_to_rocdl_f32
 // CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
 func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
-  %c0 = arith.constant 0 : i32
-  %c12 = arith.constant 12 : i32
-  %c32 = arith.constant 32 : i32
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
   // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
   // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
   // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
   // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
-  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
   // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
   // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
   func.return
 }

>From 81bb8bc2aa6b75aabe6a0d50436cabca90287691 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:37:42 -0400
Subject: [PATCH 08/12] update test files.

---
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 23 ----------------
 .../Conversion/AMDGPUToROCDL/load_lds.mlir    | 26 +++++++++++++++++++
 2 files changed, 26 insertions(+), 23 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 5dad6b75f9ec8..8871b2ce0eadb 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,7 +9,6 @@
 // test pass doesn't set up the GPU address space conversions.
 
 #gpu_global_addrspace = 1
-#gpu_lds_addrspace = 3
 
 // CHECK-LABEL: func @fat_raw_buffer_cast
 func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -462,25 +461,3 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
-
-// CHECK-LABEL: func @global_load_to_rocdl_f32
-// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
-func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
-  %c0 = arith.constant 0 : index
-  %c12 = arith.constant 12 : index
-  %c32 = arith.constant 32 : index
-  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
-  // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
-  // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
-  // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
-  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
-  // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
-  // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
-  func.return
-}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..507928eaf0d78
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}

>From 73629f4f09c581e027943804270c582621bc7192 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 16:09:58 -0400
Subject: [PATCH 09/12] linting

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++-------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        |  2 +-
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a644c24a8d6bc..5e8cbf262f0eb 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,8 +903,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GatherToLDSOpLowering
-    : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
 
@@ -936,15 +935,13 @@ struct GatherToLDSOpLowering
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    auto convertIndices =
-        [&](ValueRange indices) -> SmallVector<Value, 4> {
+    auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
       SmallVector<Value, 4> convertedIndices;
-      
+
       for (Value index : indices) {
         Type convertedType = getTypeConverter()->convertType(index.getType());
         auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
-            loc, convertedType,
-            rewriter.getIntegerAttr(convertedType, 0));
+            loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
         convertedIndices.push_back(convertedIndex);
       }
       return convertedIndices;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 605f0a92224e9..00d67f58aee1c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -120,7 +120,7 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
   else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
   else if (auto gpuMemorySpace =
-          llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
   return false;
 }

>From b48370161416b7ddce8734664162d282ca4c2bcb Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:52:19 -0400
Subject: [PATCH 10/12] Update tests.

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  22 +--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |   7 +-
 .../Conversion/AMDGPUToROCDL/load_lds.mlir    | 137 +++++++++++++++++-
 3 files changed, 139 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5e8cbf262f0eb..7efe70fd16a04 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,24 +935,10 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
-      SmallVector<Value, 4> convertedIndices;
-
-      for (Value index : indices) {
-        Type convertedType = getTypeConverter()->convertType(index.getType());
-        auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
-            loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
-        convertedIndices.push_back(convertedIndex);
-      }
-      return convertedIndices;
-    };
-
-    Value srcPtr =
-        getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
-                             convertIndices(op.getSrcIndices()), rewriter);
-    Value dstPtr =
-        getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
-                             convertIndices(op.getDstIndices()), rewriter);
+    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                                        (adaptor.getSrcIndices()), rewriter);
+    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                                        (adaptor.getDstIndices()), rewriter);
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 00d67f58aee1c..6309dac761e4f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -117,10 +117,9 @@ LogicalResult FatRawBufferCastOp::verify() {
 static bool hasGlobalMemorySpace(Attribute memorySpace) {
   if (!memorySpace)
     return true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
   return false;
 }
@@ -128,6 +127,8 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
   if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
   return false;
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 507928eaf0d78..c887f6504172a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -10,17 +10,142 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
-  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
-  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
+    : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]], %[[C0]], %[[C0_2]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
+    : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
   // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
-  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+  
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
   // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
-  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
   // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
+    : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
   func.return
 }
+
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+  // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+  // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] 
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
+    : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  func.return
+}
\ No newline at end of file

>From 1a40d6c20698052a569c8f1a49425ed6acb3834c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:59:25 -0400
Subject: [PATCH 11/12] update again

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td    |  2 +-
 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 350b184dae6a2..52753e2faf901 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -795,7 +795,7 @@ def AMDGPU_GatherToLDSOp :
     Note: only enabled for gfx942 and later.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index c887f6504172a..160e5b203ed1f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -39,8 +39,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
-    : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
   func.return
 }
 
@@ -80,8 +80,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
-    : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
   func.return
 }
 
@@ -121,8 +121,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
   %c12 = arith.constant 12 : index
   %c32 = arith.constant 32 : index
   %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
-  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
-    : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
   func.return
 }
 
@@ -145,7 +145,7 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
   // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
   %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
-    : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+    : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
   func.return
 }
\ No newline at end of file

>From d68db39da2153b53c6f06c3521220b1a5090ee07 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 7 Apr 2025 13:59:06 -0400
Subject: [PATCH 12/12] Final touch.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td       | 3 ++-
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 ++++---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        | 8 ++++----
 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir    | 2 +-
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 52753e2faf901..fc9d2e66ab468 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,7 +784,8 @@ def AMDGPU_GatherToLDSOp :
     * `$srcIndices`: indices into `$src` to read from for this thread.
     * `$dst`: LDS memory memref to write to.
     * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
-      number of subgroup size of elements will be written contiguously to `$dst[$dstIndices]`.
+      The elements gathered by the subgroup will be written contiguously, in order of lane ID,
+      starting at `$dst[$dstIndices]`.
     * `$transferType`: type of the data to be transferred by each thread. This is used to determine
       the size of the data to be transferred and the number of threads in the subgroup.
       The transfer type must be a scalar type or a vector type with a single element type.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7efe70fd16a04..ed602242a3023 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,11 +925,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     // `global_load_lds` instructions.
     size_t loadWidth;
     Type transferType = op.getTransferType();
-    if (auto transferVectorType = dyn_cast<VectorType>(transferType))
+    if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
       loadWidth = transferVectorType.getNumElements() *
-                  transferVectorType.getElementTypeBitWidth() / 8;
-    else
+                  (transferVectorType.getElementTypeBitWidth() / 8);
+    } else {
       loadWidth = transferType.getIntOrFloatBitWidth() / 8;
+    }
 
     // Currently only 1, 2, and 4 byte loads are supported.
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 6309dac761e4f..dd20babd76c3a 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -117,17 +117,17 @@ LogicalResult FatRawBufferCastOp::verify() {
 static bool hasGlobalMemorySpace(Attribute memorySpace) {
   if (!memorySpace)
     return true;
-  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
   return false;
 }
 
 static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
-  if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
     return intMemorySpace.getInt() == 3;
-  if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
     return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
   return false;
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 160e5b203ed1f..8da1f55ac338a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -148,4 +148,4 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
   amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
     : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
   func.return
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list