[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)

Alan Li llvmlistbot at llvm.org
Fri Mar 28 13:56:35 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/133498

>From ba384661402929671a0de82ffd813548af38f3de Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 14:50:45 -0400
Subject: [PATCH 1/5] Make it compilable.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..53323be7d6c5c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
+// MI300: limit to at most 4 bytes per element.
+def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
+                                VectorOfLengthAndType<[2], [F16, BF16, I16]>,
+                                VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
+                               ]>;
+
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
@@ -765,4 +771,30 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_GlobalLoadLDSOp :
+    AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Variadic<I32>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<I32>:$dstIndices
+                   )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
+
+    The 
+
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU

>From 8e46990cc375af4a2dbf908da1cb79a3a93c138c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 15:38:00 -0400
Subject: [PATCH 2/5] Adding a verifier.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 30 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 53323be7d6c5c..b7cfed1f18c6c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -792,7 +792,7 @@ def AMDGPU_GlobalLoadLDSOp :
 
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..64003e9a2dca6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Metadata.h"
 
 #include <limits>
 #include <optional>
@@ -459,6 +461,34 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+LogicalResult GlobalLoadLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
+      !memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "source and destination types must have static shape  and contiguous");
+
+  // Check $src and $dst element types are the same.
+  if (srcType.getElementType() != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // Check $src and $dst memory spaces.
+  auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
+  auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
+  if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
+    return emitOpError("source memory address space must be Global");
+  if (dstAddrSpace.getInt() != 3)
+    return emitOpError("destination memory address space must be Workgroup");
+
+  // Check chunk size compatible with element type.
+  // TODO
+  auto dstChunkSize = dstType.getShape();
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

>From c072e660b1147a4f0b439f41735edc3df2f07bc5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 26 Mar 2025 22:30:13 -0400
Subject: [PATCH 3/5] checkpoint

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 76 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b7cfed1f18c6c..e60eaa6d11897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
+
 def AMDGPU_GlobalLoadLDSOp :
     AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
     Arguments<(ins
-                   Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+                   Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
                    Variadic<I32>:$srcIndices,
-                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
                    Variadic<I32>:$dstIndices
                    )>,
     Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
 
     The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
 
-    The 
-
+    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+    will write to.
   }];
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..83e18a3d87b47 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+    if (elemSizeInBits % 8 != 0)
+      return op.emitOpError("element size must be a multiple of 8");
+    auto loadWidth = elemSizeInBits / 8;
+
+    // TODO: add chipset support check
+    if (chipset.majorVersion >= 12)
+      return op.emitOpError("TODO");
+
+    // TODO: fold this into chipset check.
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+      return op.emitOpError("unsupported element size");
+
+    Value src = adaptor.getSrc();
+    Value dst = adaptor.getDst();
+    Value memrefSrc = op.getSrc();
+    Value memrefDst = op.getDst();
+
+    // Collapse src memref with indices:
+    auto flattenIndex = [&](Value memref, MemRefType memrefType,
+                            ValueRange indices) -> std::optional<Value> {
+      MemRefDescriptor memRefDescriptor(memref);
+      int64_t offset = 0;
+      SmallVector<int64_t, 5> strides;
+      if (failed(memrefType.getStridesAndOffset(strides, offset)))
+        return {};
+      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
+                               strides);
+    };
+
+    // Source
+    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                  op.getSrcIndices());
+    if (!optSrcIdx)
+      return op.emitOpError("failed to flatten source memref indices");
+    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                  op.getDstIndices());
+    if (!optDstIdx)
+      return op.emitOpError("failed to flatten destination memref indices");
+
+    Type srcPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType =
+        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Value srcPtr = rewriter.create<LLVM::GEPOp>(
+        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
+    
+    Value dstPtr = rewriter.create<LLVM::GEPOp>(
+        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0));
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                      chipset);
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GlobalLoadLDSOpLowering>(converter, chipset);
 }

>From 564ebc8cd29ffe7437e1b21aacb626874e3a2453 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 10:15:56 -0400
Subject: [PATCH 4/5] Make it work

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 47 +++++++++++--------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 83e18a3d87b47..4169eb7825c95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
-struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+struct GlobalLoadLDSOpLowering
+    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
   GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
 
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
     if (elemSizeInBits % 8 != 0)
       return op.emitOpError("element size must be a multiple of 8");
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
     // TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
     Value memrefSrc = op.getSrc();
     Value memrefDst = op.getDst();
 
-    // Collapse src memref with indices:
-    auto flattenIndex = [&](Value memref, MemRefType memrefType,
-                            ValueRange indices) -> std::optional<Value> {
+    // Collapse src memref with indices, returns the base pointer and linearized
+    // index.
+    auto flattenIndex =
+        [&](Value memref, MemRefType memrefType,
+            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
       int64_t offset = 0;
       SmallVector<int64_t, 5> strides;
       if (failed(memrefType.getStridesAndOffset(strides, offset)))
         return {};
-      return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
-                               strides);
+      return std::make_pair(
+          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                     memrefType),
+          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
     // Source
-    auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                  op.getSrcIndices());
-    if (!optSrcIdx)
+    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+                                     op.getSrcIndices());
+    if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                  op.getDstIndices());
-    if (!optDstIdx)
+    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+                                     op.getDstIndices());
+    if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
-    Type srcPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    Type dstPtrType =
-        LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
     Value srcPtr = rewriter.create<LLVM::GEPOp>(
-        loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
-    
+        loc, srcPtrType, elemType, optSrcBuffer->first,
+        ArrayRef<Value>({optSrcBuffer->second}));
+
     Value dstPtr = rewriter.create<LLVM::GEPOp>(
-        loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+        loc, dstPtrType, elemType, optDstBuffer->first,
+        ArrayRef<Value>({optDstBuffer->second}));
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),

>From 92a1ef9d5fa73dbb5c486a1e7dc29290d48b2067 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 13:04:27 -0400
Subject: [PATCH 5/5] update AMDGPU description.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++++-----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 30 ++++++++-----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  5 ----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 23 ++++++++++++++
 4 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e60eaa6d11897..606cf69537ca1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
     Results<(outs)> {
   let summary = "MLIR wrapper for CDNA mfma instructions";
   let description = [{
-    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
-    for various `mfma` instructions in the CDNA architecture, which perform
-    multiple outer products in order to allow fast matrix multiplication.
-
-    The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
-
-    The `$src`, along with its indices, points to the memory location this thread reads from.
+    The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+      A number of elements equal to the subgroup size will be written contiguously starting at `$dst[$dstIndices]`.
+    
     The `$dst`, along with its indices, points to the memory location the subgroup of this thread
     will write to.
+  
+    Note: only enabled for gfx942 and later.
   }];
   let assemblyFormat = [{
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4169eb7825c95..ffb347ade326c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
     // `global_load_lds` instructions.
     auto loadWidth = elemSizeInBits / 8;
 
-    // TODO: add chipset support check
-    if (chipset.majorVersion >= 12)
-      return op.emitOpError("TODO");
+    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+    if (chipset < GlobalLoadEnabled)
+      return op.emitOpError("chipset not supported");
 
-    // TODO: fold this into chipset check.
     // Currently only 1, 2, and 4 byte loads are supported.
     if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
-      return op.emitOpError("unsupported element size");
+      return op.emitOpError("chipset unsupported element size");
 
-    Value src = adaptor.getSrc();
-    Value dst = adaptor.getDst();
-    Value memrefSrc = op.getSrc();
-    Value memrefDst = op.getDst();
-
-    // Collapse src memref with indices, returns the base pointer and linearized
-    // index.
-    auto flattenIndex =
+    // Return pair of {base pointer, linearized index}.
+    auto getBasePtrAndLinearizedIndex =
         [&](Value memref, MemRefType memrefType,
             ValueRange indices) -> std::optional<std::pair<Value, Value>> {
       MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
           getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
     };
 
-    // Source
-    auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
-                                     op.getSrcIndices());
+    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+        op.getSrcIndices());
     if (!optSrcBuffer)
       return op.emitOpError("failed to flatten source memref indices");
-    auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
-                                     op.getDstIndices());
+    auto optDstBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+        op.getDstIndices());
     if (!optDstBuffer)
       return op.emitOpError("failed to flatten destination memref indices");
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 64003e9a2dca6..efec76c91ea23 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -481,11 +481,6 @@ LogicalResult GlobalLoadLDSOp::verify() {
     return emitOpError("source memory address space must be Global");
   if (dstAddrSpace.getInt() != 3)
     return emitOpError("destination memory address space must be Workgroup");
-
-  // Check chunk size compatible with element type.
-  // TODO
-  auto dstChunkSize = dstType.getShape();
-
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..65515dbe30b29 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,6 +9,7 @@
 // test pass doesn't set up the GPU address space conversions.
 
 #gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
 
 // CHECK-LABEL: func @fat_raw_buffer_cast
 func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : i32
+  %c12 = arith.constant 12 : i32
+  %c32 = arith.constant 32 : i32
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+  // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
+  // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
+  // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+  // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+  amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}



More information about the Mlir-commits mailing list