[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)
Alan Li
llvmlistbot at llvm.org
Fri Mar 28 13:56:35 PDT 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/133498
>From ba384661402929671a0de82ffd813548af38f3de Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 14:50:45 -0400
Subject: [PATCH 1/5] Make it compilable.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..53323be7d6c5c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
+// MI300: limit to at most 4 bytes per element.
+def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
+ VectorOfLengthAndType<[2], [F16, BF16, I16]>,
+ VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
+ ]>;
+
def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
@@ -765,4 +771,30 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def AMDGPU_GlobalLoadLDSOp :
+ AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+ Arguments<(ins
+ Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+ Variadic<I32>:$srcIndices,
+ Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+ Variadic<I32>:$dstIndices
+ )>,
+ Results<(outs)> {
+ let summary = "MLIR wrapper for CDNA global_load_lds instructions";
+ let description = [{
+ The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+ for various `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products in order to allow fast matrix multiplication.
+
+ The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
+
+ The
+
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // AMDGPU
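For orientation, here is a sketch of the intended usage, written in the comma-separated assembly form the series settles on in later patches (shapes and indices are illustrative only, not taken from a real kernel):

  %c0 = arith.constant 0 : i32
  amdgpu.global_load %global[%c0, %c0], %alloc[%c0, %c0]
      : memref<128x72xf32, 1>, memref<64x64xf32, 3>

Each thread names one global element to read, and the hardware writes the loaded data directly into the LDS allocation without staging it through registers.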
>From 8e46990cc375af4a2dbf908da1cb79a3a93c138c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 15:38:00 -0400
Subject: [PATCH 2/5] Adding a verifier.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 30 +++++++++++++++++++
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 53323be7d6c5c..b7cfed1f18c6c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -792,7 +792,7 @@ def AMDGPU_GlobalLoadLDSOp :
}];
let assemblyFormat = [{
- $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+ $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..64003e9a2dca6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Metadata.h"
#include <limits>
#include <optional>
@@ -459,6 +461,34 @@ LogicalResult DPPOp::verify() {
return success();
}
+LogicalResult GlobalLoadLDSOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+ MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+ if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
+ !memref::isStaticShapeAndContiguousRowMajor(dstType))
+ return emitOpError(
+ "source and destination types must have static shape and contiguous");
+
+ // Check $src and $dst element types are the same.
+ if (srcType.getElementType() != dstType.getElementType())
+ return emitOpError("source and destination element types must match");
+
+ // Check $src and $dst memory spaces.
+ auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
+ auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
+ if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
+ return emitOpError("source memory address space must be Global");
+ if (!dstAddrSpace || dstAddrSpace.getInt() != 3)
+ return emitOpError("destination memory address space must be Workgroup");
+
+ // Check chunk size compatible with element type.
+ // TODO
+ auto dstChunkSize = dstType.getShape();
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
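To make the verifier's address-space checks concrete, here is a sketch of IR it should reject (hypothetical example in the final assembly syntax from later in this series; 1 is the global and 3 the workgroup numeric address space, as checked above):

  // Invalid: destination is in global memory (space 1), not LDS (space 3).
  amdgpu.global_load %src[%i], %dst[%j] : memref<64xf32, 1>, memref<64xf32, 1>

This should produce the "destination memory address space must be Workgroup" diagnostic.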
>From c072e660b1147a4f0b439f41735edc3df2f07bc5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 26 Mar 2025 22:30:13 -0400
Subject: [PATCH 3/5] checkpoint
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 76 ++++++++++++++++++-
2 files changed, 82 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b7cfed1f18c6c..e60eaa6d11897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
+
def AMDGPU_GlobalLoadLDSOp :
AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
Arguments<(ins
- Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+ Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
Variadic<I32>:$srcIndices,
- Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+ Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
Variadic<I32>:$dstIndices
)>,
Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
- The
-
+ The `$src`, along with its indices, points to the memory location this thread reads from.
+ The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+ will write to.
}];
let assemblyFormat = [{
- $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..83e18a3d87b47 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+ GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+ size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+ if (elemSizeInBits % 8 != 0)
+ return op.emitOpError("element size must be a multiple of 8");
+ auto loadWidth = elemSizeInBits / 8;
+
+ // TODO: add chipset support check
+ if (chipset.majorVersion >= 12)
+ return op.emitOpError("TODO");
+
+ // TODO: fold this into chipset check.
+ // Currently only 1, 2, and 4 byte loads are supported.
+ if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+ return op.emitOpError("unsupported element size");
+
+ Value src = adaptor.getSrc();
+ Value dst = adaptor.getDst();
+ Value memrefSrc = op.getSrc();
+ Value memrefDst = op.getDst();
+
+ // Collapse src memref with indices:
+ auto flattenIndex = [&](Value memref, MemRefType memrefType,
+ ValueRange indices) -> std::optional<Value> {
+ MemRefDescriptor memRefDescriptor(memref);
+ int64_t offset = 0;
+ SmallVector<int64_t, 5> strides;
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
+ return {};
+ return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
+ strides);
+ };
+
+ // Source
+ auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+ op.getSrcIndices());
+ if (!optSrcIdx)
+ return op.emitOpError("failed to flatten source memref indices");
+ auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+ op.getDstIndices());
+ if (!optDstIdx)
+ return op.emitOpError("failed to flatten destination memref indices");
+
+ Type srcPtrType =
+ LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+ Type dstPtrType =
+ LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+ Value srcPtr = rewriter.create<LLVM::GEPOp>(
+ loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
+
+ Value dstPtr = rewriter.create<LLVM::GEPOp>(
+ loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+
+ rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+ op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+ createI32Constant(rewriter, loc, 0),
+ createI32Constant(rewriter, loc, 0));
+
+ return success();
+ }
+};
+
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
- chipset);
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GlobalLoadLDSOpLowering>(converter, chipset);
}
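The linearized index from `getLinearIndexI32` is the usual dot product of the indices with the row-major strides. As a rough sketch, for a `memref<128x72xf32, 1>` source indexed by `(%i, %j)`, the emitted arithmetic amounts to:

  %stride = llvm.mlir.constant(72 : i32) : i32
  %scaled = llvm.mul %i, %stride : i32
  %linear = llvm.add %scaled, %j : i32

and `%linear` then drives the `llvm.getelementptr` over `f32` in the matching address space.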
>From 564ebc8cd29ffe7437e1b21aacb626874e3a2453 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 10:15:56 -0400
Subject: [PATCH 4/5] Make it work
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 47 +++++++++++--------
1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 83e18a3d87b47..4169eb7825c95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
-struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+struct GlobalLoadLDSOpLowering
+ : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
if (elemSizeInBits % 8 != 0)
return op.emitOpError("element size must be a multiple of 8");
+
+ // TODO: instead of transferring only one element per thread, we could
+ // augment it to transfer multiple elements per thread by issuing multiple
+ // `global_load_lds` instructions.
auto loadWidth = elemSizeInBits / 8;
// TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
Value memrefSrc = op.getSrc();
Value memrefDst = op.getDst();
- // Collapse src memref with indices:
- auto flattenIndex = [&](Value memref, MemRefType memrefType,
- ValueRange indices) -> std::optional<Value> {
+ // Collapse src memref with indices, returns the base pointer and linearized
+ // index.
+ auto flattenIndex =
+ [&](Value memref, MemRefType memrefType,
+ ValueRange indices) -> std::optional<std::pair<Value, Value>> {
MemRefDescriptor memRefDescriptor(memref);
int64_t offset = 0;
SmallVector<int64_t, 5> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset)))
return {};
- return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
- strides);
+ return std::make_pair(
+ memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType),
+ getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
};
// Source
- auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
- op.getSrcIndices());
- if (!optSrcIdx)
+ auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+ op.getSrcIndices());
+ if (!optSrcBuffer)
return op.emitOpError("failed to flatten source memref indices");
- auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
- op.getDstIndices());
- if (!optDstIdx)
+ auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+ op.getDstIndices());
+ if (!optDstBuffer)
return op.emitOpError("failed to flatten destination memref indices");
- Type srcPtrType =
- LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
- Type dstPtrType =
- LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+ Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+ Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
Value srcPtr = rewriter.create<LLVM::GEPOp>(
- loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
-
+ loc, srcPtrType, elemType, optSrcBuffer->first,
+ ArrayRef<Value>({optSrcBuffer->second}));
+
Value dstPtr = rewriter.create<LLVM::GEPOp>(
- loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+ loc, dstPtrType, elemType, optDstBuffer->first,
+ ArrayRef<Value>({optDstBuffer->second}));
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
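The move from GEPing off the adaptor value to `memRefDescriptor.bufferPtr(...)` means the memref's offset is folded into the base pointer rather than silently dropped. A hypothetical case where that matters, assuming the verifier's contiguity check tolerates a static nonzero offset:

  // The subview's descriptor carries offset 8; bufferPtr bakes it into the
  // base pointer before the linearized index is applied.
  %view = memref.subview %buf[8] [64] [1]
      : memref<128xf32, 1> to memref<64xf32, strided<[1], offset: 8>, 1>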
>From 92a1ef9d5fa73dbb5c486a1e7dc29290d48b2067 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 13:04:27 -0400
Subject: [PATCH 5/5] update AMDGPU description.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++++-----
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 30 ++++++++-----------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 5 ----
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 23 ++++++++++++++
4 files changed, 46 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e60eaa6d11897..606cf69537ca1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
Results<(outs)> {
let summary = "MLIR wrapper for CDNA global_load_lds instructions";
let description = [{
- The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
- for various `mfma` instructions in the CDNA architecture, which perform
- multiple outer products in order to allow fast matrix multiplication.
-
- The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
-
- The `$src`, along with its indices, points to the memory location this thread reads from.
+ The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
+
+ Operands:
+ * `$src`: global memory memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$dst`: LDS memory memref to write to.
+ * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+ A subgroup-size number of elements will be written contiguously to `$dst[$dstIndices]`.
+
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
will write to.
+
+ Note: only enabled for gfx942 and later.
}];
let assemblyFormat = [{
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4169eb7825c95..ffb347ade326c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
// `global_load_lds` instructions.
auto loadWidth = elemSizeInBits / 8;
- // TODO: add chipset support check
- if (chipset.majorVersion >= 12)
- return op.emitOpError("TODO");
+ const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+ if (chipset < GlobalLoadEnabled)
+ return op.emitOpError("chipset not supported");
- // TODO: fold this into chipset check.
// Currently only 1, 2, and 4 byte loads are supported.
if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
- return op.emitOpError("unsupported element size");
+ return op.emitOpError("chipset unsupported element size");
- Value src = adaptor.getSrc();
- Value dst = adaptor.getDst();
- Value memrefSrc = op.getSrc();
- Value memrefDst = op.getDst();
-
- // Collapse src memref with indices, returns the base pointer and linearized
- // index.
- auto flattenIndex =
+ // Return pair of {base pointer, linearized index}.
+ auto getBasePtrAndLinearizedIndex =
[&](Value memref, MemRefType memrefType,
ValueRange indices) -> std::optional<std::pair<Value, Value>> {
MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
};
- // Source
- auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
- op.getSrcIndices());
+ auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+ op.getSrcIndices());
if (!optSrcBuffer)
return op.emitOpError("failed to flatten source memref indices");
- auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
- op.getDstIndices());
+ auto optDstBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+ op.getDstIndices());
if (!optDstBuffer)
return op.emitOpError("failed to flatten destination memref indices");
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 64003e9a2dca6..efec76c91ea23 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -481,11 +481,6 @@ LogicalResult GlobalLoadLDSOp::verify() {
return emitOpError("source memory address space must be Global");
if (!dstAddrSpace || dstAddrSpace.getInt() != 3)
return emitOpError("destination memory address space must be Workgroup");
-
- // Check chunk size compatible with element type.
- // TODO
- auto dstChunkSize = dstType.getShape();
-
return success();
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..65515dbe30b29 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,6 +9,7 @@
// test pass doesn't set up the GPU address space conversions.
#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
// CHECK-LABEL: func @fat_raw_buffer_cast
func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+ %c0 = arith.constant 0 : i32
+ %c12 = arith.constant 12 : i32
+ %c32 = arith.constant 32 : i32
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+ // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+ // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
+ // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+ // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+ amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}