[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)
Alan Li
llvmlistbot at llvm.org
Wed Apr 2 17:59:43 PDT 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/133498
>From ba384661402929671a0de82ffd813548af38f3de Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 14:50:45 -0400
Subject: [PATCH 01/11] make it compilable.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c0b3e5540b1df..53323be7d6c5c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,6 +668,12 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
+// MI300: limit to only 4 bytes per elements.
+def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
+ VectorOfLengthAndType<[2], [F16, BF16, I16]>,
+ VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
+ ]>;
+
def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
@@ -765,4 +771,30 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def AMDGPU_GlobalLoadLDSOp :
+ AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+ Arguments<(ins
+ Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+ Variadic<I32>:$srcIndices,
+ Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+ Variadic<I32>:$dstIndices
+ )>,
+ Results<(outs)> {
+ let summary = "MLIR wrapper for CDNA mfma instructions";
+ let description = [{
+ The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+ for various `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products in order to allow fast matrix multiplication.
+
+ The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
+
+ The
+
+ }];
+ let assemblyFormat = [{
+ $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // AMDGPU
>From 8e46990cc375af4a2dbf908da1cb79a3a93c138c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 25 Mar 2025 15:38:00 -0400
Subject: [PATCH 02/11] Adding a verifier.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 30 +++++++++++++++++++
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 53323be7d6c5c..b7cfed1f18c6c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -792,7 +792,7 @@ def AMDGPU_GlobalLoadLDSOp :
}];
let assemblyFormat = [{
- $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` `,` type($src) `,` type($dst)
+ $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2b2a167b90c82..64003e9a2dca6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Metadata.h"
#include <limits>
#include <optional>
@@ -459,6 +461,34 @@ LogicalResult DPPOp::verify() {
return success();
}
+LogicalResult GlobalLoadLDSOp::verify() {
+ MemRefType srcType = cast<MemRefType>(getSrc().getType());
+ MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+ if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
+ !memref::isStaticShapeAndContiguousRowMajor(dstType))
+ return emitOpError(
+ "source and destination types must have static shape and contiguous");
+
+ // Check $src and $dst element types are the same.
+ if (srcType.getElementType() != dstType.getElementType())
+ return emitOpError("source and destination element types must match");
+
+ // Check $src and $dst memory spaces.
+ auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
+ auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
+ if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
+ return emitOpError("source memory address space must be Global");
+ if (dstAddrSpace.getInt() != 3)
+ return emitOpError("destination memory address space must be Workgroup");
+
+ // Check chunk size compatible with element type.
+ // TODO
+ auto dstChunkSize = dstType.getShape();
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
>From c072e660b1147a4f0b439f41735edc3df2f07bc5 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 26 Mar 2025 22:30:13 -0400
Subject: [PATCH 03/11] checkpoint
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 13 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 76 ++++++++++++++++++-
2 files changed, 82 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b7cfed1f18c6c..e60eaa6d11897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -771,12 +771,14 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
+def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
+
def AMDGPU_GlobalLoadLDSOp :
AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
Arguments<(ins
- Arg<AnyMemRef, "buffer to read from", [MemRead]>:$src,
+ Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
Variadic<I32>:$srcIndices,
- Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+ Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
Variadic<I32>:$dstIndices
)>,
Results<(outs)> {
@@ -788,11 +790,12 @@ def AMDGPU_GlobalLoadLDSOp :
The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
- The
-
+ The `$src`, along with its indices, points to the memory location this thread reads from.
+ The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+ will write to.
}];
let assemblyFormat = [{
- $src `[` $srcIndices `]` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470cff7f5..83e18a3d87b47 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,6 +903,78 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+ GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+ size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+ if (elemSizeInBits % 8 != 0)
+ return op.emitOpError("element size must be a multiple of 8");
+ auto loadWidth = elemSizeInBits / 8;
+
+ // TODO: add chipset support check
+ if (chipset.majorVersion >= 12)
+ return op.emitOpError("TODO");
+
+ // TODO: fold this into chipset check.
+ // Currently only 1, 2, and 4 byte loads are supported.
+ if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+ return op.emitOpError("unsupported element size");
+
+ Value src = adaptor.getSrc();
+ Value dst = adaptor.getDst();
+ Value memrefSrc = op.getSrc();
+ Value memrefDst = op.getDst();
+
+ // Collapse src memref with indices:
+ auto flattenIndex = [&](Value memref, MemRefType memrefType,
+ ValueRange indices) -> std::optional<Value> {
+ MemRefDescriptor memRefDescriptor(memref);
+ int64_t offset = 0;
+ SmallVector<int64_t, 5> strides;
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
+ return {};
+ return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
+ strides);
+ };
+
+ // Source
+ auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+ op.getSrcIndices());
+ if (!optSrcIdx)
+ return op.emitOpError("failed to flatten source memref indices");
+ auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+ op.getDstIndices());
+ if (!optDstIdx)
+ return op.emitOpError("failed to flatten destination memref indices");
+
+ Type srcPtrType =
+ LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+ Type dstPtrType =
+ LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+ Value srcPtr = rewriter.create<LLVM::GEPOp>(
+ loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
+
+ Value dstPtr = rewriter.create<LLVM::GEPOp>(
+ loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+
+ rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+ op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+ createI32Constant(rewriter, loc, 0),
+ createI32Constant(rewriter, loc, 0));
+
+ return success();
+ }
+};
+
namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1286,6 +1358,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
- chipset);
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GlobalLoadLDSOpLowering>(converter, chipset);
}
>From 564ebc8cd29ffe7437e1b21aacb626874e3a2453 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 10:15:56 -0400
Subject: [PATCH 04/11] Make it work
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 47 +++++++++++--------
1 file changed, 28 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 83e18a3d87b47..4169eb7825c95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,7 +903,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
-struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+struct GlobalLoadLDSOpLowering
+ : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
@@ -918,6 +919,10 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
if (elemSizeInBits % 8 != 0)
return op.emitOpError("element size must be a multiple of 8");
+
+ // TODO: instead of only transferring one element per thread, we could
+ // augment it to transfer multiple elements per thread by issuing multiple
+ // `global_load_lds` instructions.
auto loadWidth = elemSizeInBits / 8;
// TODO: add chipset support check
@@ -934,37 +939,41 @@ struct GlobalLoadLDSOpLowering : public ConvertOpToLLVMPattern<GlobalLoadLDSOp>
Value memrefSrc = op.getSrc();
Value memrefDst = op.getDst();
- // Collapse src memref with indices:
- auto flattenIndex = [&](Value memref, MemRefType memrefType,
- ValueRange indices) -> std::optional<Value> {
+ // Collapse src memref with indices, returns the base pointer and linearized
+ // index.
+ auto flattenIndex =
+ [&](Value memref, MemRefType memrefType,
+ ValueRange indices) -> std::optional<std::pair<Value, Value>> {
MemRefDescriptor memRefDescriptor(memref);
int64_t offset = 0;
SmallVector<int64_t, 5> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset)))
return {};
- return getLinearIndexI32(rewriter, loc, memRefDescriptor, indices,
- strides);
+ return std::make_pair(
+ memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType),
+ getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
};
// Source
- auto optSrcIdx = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
- op.getSrcIndices());
- if (!optSrcIdx)
+ auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
+ op.getSrcIndices());
+ if (!optSrcBuffer)
return op.emitOpError("failed to flatten source memref indices");
- auto optDstIdx = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
- op.getDstIndices());
- if (!optDstIdx)
+ auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
+ op.getDstIndices());
+ if (!optDstBuffer)
return op.emitOpError("failed to flatten destination memref indices");
- Type srcPtrType =
- LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
- Type dstPtrType =
- LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+ Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+ Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
Value srcPtr = rewriter.create<LLVM::GEPOp>(
- loc, srcPtrType, elemType, src, ArrayRef<Value>({*optSrcIdx}));
-
+ loc, srcPtrType, elemType, optSrcBuffer->first,
+ ArrayRef<Value>({optSrcBuffer->second}));
+
Value dstPtr = rewriter.create<LLVM::GEPOp>(
- loc, dstPtrType, elemType, dst, ArrayRef<Value>({*optDstIdx}));
+ loc, dstPtrType, elemType, optDstBuffer->first,
+ ArrayRef<Value>({optDstBuffer->second}));
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
>From 92a1ef9d5fa73dbb5c486a1e7dc29290d48b2067 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 27 Mar 2025 13:04:27 -0400
Subject: [PATCH 05/11] update AMDGPU description.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++++-----
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 30 ++++++++-----------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 5 ----
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 23 ++++++++++++++
4 files changed, 46 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e60eaa6d11897..606cf69537ca1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
Results<(outs)> {
let summary = "MLIR wrapper for CDNA mfma instructions";
let description = [{
- The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
- for various `mfma` instructions in the CDNA architecture, which perform
- multiple outer products in order to allow fast matrix multiplication.
-
- The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
-
- The `$src`, along with its indices, points to the memory location this thread reads from.
+ The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
+
+ Operands:
+ * `$src`: global memory memref to read from.
+ * `$srcIndices`: indices into `$src` to read from for this thread.
+ * `$dst`: LDS memory memref to write to.
+ * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+ A number of elements equal to the subgroup size will be written contiguously to `$dst[$dstIndices]`.
+
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
will write to.
+
+ Note: only enabled for gfx942 and later.
}];
let assemblyFormat = [{
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4169eb7825c95..ffb347ade326c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
// `global_load_lds` instructions.
auto loadWidth = elemSizeInBits / 8;
- // TODO: add chipset support check
- if (chipset.majorVersion >= 12)
- return op.emitOpError("TODO");
+ const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+ if (chipset < GlobalLoadEnabled)
+ return op.emitOpError("chipset not supported");
- // TODO: fold this into chipset check.
// Currently only 1, 2, and 4 byte loads are supported.
if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
- return op.emitOpError("unsupported element size");
+ return op.emitOpError("chipset unsupported element size");
- Value src = adaptor.getSrc();
- Value dst = adaptor.getDst();
- Value memrefSrc = op.getSrc();
- Value memrefDst = op.getDst();
-
- // Collapse src memref with indices, returns the base pointer and linearized
- // index.
- auto flattenIndex =
+ // Return pair of {base pointer, linearized index}.
+ auto getBasePtrAndLinearizedIndex =
[&](Value memref, MemRefType memrefType,
ValueRange indices) -> std::optional<std::pair<Value, Value>> {
MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
};
- // Source
- auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
- op.getSrcIndices());
+ auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+ op.getSrcIndices());
if (!optSrcBuffer)
return op.emitOpError("failed to flatten source memref indices");
- auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
- op.getDstIndices());
+ auto optDstBuffer = getBasePtrAndLinearizedIndex(
+ adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+ op.getDstIndices());
if (!optDstBuffer)
return op.emitOpError("failed to flatten destination memref indices");
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 64003e9a2dca6..efec76c91ea23 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -481,11 +481,6 @@ LogicalResult GlobalLoadLDSOp::verify() {
return emitOpError("source memory address space must be Global");
if (dstAddrSpace.getInt() != 3)
return emitOpError("destination memory address space must be Workgroup");
-
- // Check chunk size compatible with element type.
- // TODO
- auto dstChunkSize = dstType.getShape();
-
return success();
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..65515dbe30b29 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,6 +9,7 @@
// test pass doesn't set up the GPU address space conversions.
#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
// CHECK-LABEL: func @fat_raw_buffer_cast
func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+ %c0 = arith.constant 0 : i32
+ %c12 = arith.constant 12 : i32
+ %c32 = arith.constant 32 : i32
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+ // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+ // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
+ // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+ // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+ amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}
>From a17e854c5f2f72e486f7fc53a677afa1e708c700 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 11:21:10 -0400
Subject: [PATCH 06/11] Address comments.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 16 +++---------
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 15 ++++++-----
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 25 ++++++++++---------
3 files changed, 24 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 606cf69537ca1..70d5f5de416d1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -668,12 +668,6 @@ def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
-// MI300: limit to only 4 bytes per elements.
-def GlobalLoadTypes : AnyTypeOf<[F16, F32, I8, SI8, UI8, I16, I32,
- VectorOfLengthAndType<[2], [F16, BF16, I16]>,
- VectorOfLengthAndType<[2, 4], [I8, SI8, UI8]>
- ]>;
-
def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
@@ -771,14 +765,12 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}
-def GlobalLoadMemRefType : MemRefOf<[GlobalLoadTypes]>;
-
-def AMDGPU_GlobalLoadLDSOp :
- AMDGPU_Op<"global_load", [SameVariadicOperandSize]>,
+def AMDGPU_GatherToLDSOp :
+ AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
Arguments<(ins
- Arg<GlobalLoadMemRefType, "buffer to read from", [MemRead]>:$src,
+ Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
Variadic<I32>:$srcIndices,
- Arg<GlobalLoadMemRefType, "buffer to write to", [MemWrite]>:$dst,
+ Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
Variadic<I32>:$dstIndices
)>,
Results<(outs)> {
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ffb347ade326c..5ade785ea9b93 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,15 +903,15 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
-struct GlobalLoadLDSOpLowering
- : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
- GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
- : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+struct GatherToLDSOpLowering
+ : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+ GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
- matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+ matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
@@ -925,8 +925,7 @@ struct GlobalLoadLDSOpLowering
// `global_load_lds` instructions.
auto loadWidth = elemSizeInBits / 8;
- const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
- if (chipset < GlobalLoadEnabled)
+ if (chipset < kGfx942)
return op.emitOpError("chipset not supported");
// Currently only 1, 2, and 4 byte loads are supported.
@@ -1362,5 +1361,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GlobalLoadLDSOpLowering>(converter, chipset);
+ GatherToLDSOpLowering>(converter, chipset);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index efec76c91ea23..27e0cfb8044b3 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,7 +25,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/IR/Metadata.h"
#include <limits>
#include <optional>
@@ -461,26 +460,28 @@ LogicalResult DPPOp::verify() {
return success();
}
-LogicalResult GlobalLoadLDSOp::verify() {
+LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
- if (!memref::isStaticShapeAndContiguousRowMajor(srcType) ||
- !memref::isStaticShapeAndContiguousRowMajor(dstType))
+ if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
return emitOpError(
- "source and destination types must have static shape and contiguous");
+ "destination types must have static shape and contiguous");
+ auto elemType = srcType.getElementType();
// Check $src and $dst element types are the same.
- if (srcType.getElementType() != dstType.getElementType())
+ if (elemType != dstType.getElementType())
return emitOpError("source and destination element types must match");
- // Check $src and $dst memory spaces.
- auto srcAddrSpace = llvm::dyn_cast<IntegerAttr>(srcType.getMemorySpace());
- auto dstAddrSpace = llvm::dyn_cast<IntegerAttr>(dstType.getMemorySpace());
- if (!srcAddrSpace || srcAddrSpace.getInt() != 1)
- return emitOpError("source memory address space must be Global");
- if (dstAddrSpace.getInt() != 3)
+ // Element type sizes should be 1, 2, or 4 bytes.
+ auto elemSize = elemType.getIntOrFloatBitWidth();
+ if (elemSize != 8 && elemSize != 16 && elemSize != 32)
+ return emitOpError("source and destination element types must be 8, 16, "
+ "or 32 bits");
+
+ if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
return emitOpError("destination memory address space must be Workgroup");
+
return success();
}
>From 4ed20068c55450581929686679197ffd999c68ab Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:12:42 -0400
Subject: [PATCH 07/11] update lowering
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 10 ++-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 73 ++++++++-----------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 49 +++++++++----
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 12 +--
4 files changed, 78 insertions(+), 66 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 70d5f5de416d1..350b184dae6a2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -769,9 +769,10 @@ def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
Arguments<(ins
Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
- Variadic<I32>:$srcIndices,
+ Variadic<Index>:$srcIndices,
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
- Variadic<I32>:$dstIndices
+ Variadic<Index>:$dstIndices,
+ TypeAttr:$transferType
)>,
Results<(outs)> {
let summary = "MLIR wrapper for CDNA mfma instructions";
@@ -784,7 +785,10 @@ def AMDGPU_GatherToLDSOp :
* `$dst`: LDS memory memref to write to.
* `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
A number of elements equal to the subgroup size will be written contiguously to `$dst[$dstIndices]`.
-
+ * `$transferType`: type of the data to be transferred by each thread. This is used to determine
+ the size of the data to be transferred and the number of threads in the subgroup.
+ The transfer type must be a scalar type or a vector type with a single element type.
+
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
will write to.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5ade785ea9b93..a644c24a8d6bc 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -913,60 +913,49 @@ struct GatherToLDSOpLowering
LogicalResult
matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx942)
+ return op.emitOpError("chipset not supported");
+
Location loc = op.getLoc();
- auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
- size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
- if (elemSizeInBits % 8 != 0)
- return op.emitOpError("element size must be a multiple of 8");
+ auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+ auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
// TODO: instead of only transferring one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
// `global_load_lds` instructions.
- auto loadWidth = elemSizeInBits / 8;
-
- if (chipset < kGfx942)
- return op.emitOpError("chipset not supported");
+ size_t loadWidth;
+ Type transferType = op.getTransferType();
+ if (auto transferVectorType = dyn_cast<VectorType>(transferType))
+ loadWidth = transferVectorType.getNumElements() *
+ transferVectorType.getElementTypeBitWidth() / 8;
+ else
+ loadWidth = transferType.getIntOrFloatBitWidth() / 8;
// Currently only 1, 2, and 4 byte loads are supported.
- if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+ if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("chipset unsupported element size");
- // Return pair of {base pointer, linearized index}.
- auto getBasePtrAndLinearizedIndex =
- [&](Value memref, MemRefType memrefType,
- ValueRange indices) -> std::optional<std::pair<Value, Value>> {
- MemRefDescriptor memRefDescriptor(memref);
- int64_t offset = 0;
- SmallVector<int64_t, 5> strides;
- if (failed(memrefType.getStridesAndOffset(strides, offset)))
- return {};
- return std::make_pair(
- memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
- memrefType),
- getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
+ auto convertIndices =
+ [&](ValueRange indices) -> SmallVector<Value, 4> {
+ SmallVector<Value, 4> convertedIndices;
+
+ for (Value index : indices) {
+ Type convertedType = getTypeConverter()->convertType(index.getType());
+ auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, convertedType,
+ rewriter.getIntegerAttr(convertedType, 0));
+ convertedIndices.push_back(convertedIndex);
+ }
+ return convertedIndices;
};
- auto optSrcBuffer = getBasePtrAndLinearizedIndex(
- adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
- op.getSrcIndices());
- if (!optSrcBuffer)
- return op.emitOpError("failed to flatten source memref indices");
- auto optDstBuffer = getBasePtrAndLinearizedIndex(
- adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
- op.getDstIndices());
- if (!optDstBuffer)
- return op.emitOpError("failed to flatten destination memref indices");
-
- Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
- Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
- Value srcPtr = rewriter.create<LLVM::GEPOp>(
- loc, srcPtrType, elemType, optSrcBuffer->first,
- ArrayRef<Value>({optSrcBuffer->second}));
-
- Value dstPtr = rewriter.create<LLVM::GEPOp>(
- loc, dstPtrType, elemType, optDstBuffer->first,
- ArrayRef<Value>({optDstBuffer->second}));
+ Value srcPtr =
+ getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+ convertIndices(op.getSrcIndices()), rewriter);
+ Value dstPtr =
+ getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+ convertIndices(op.getDstIndices()), rewriter);
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 27e0cfb8044b3..605f0a92224e9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
#include <limits>
#include <optional>
@@ -113,21 +114,30 @@ LogicalResult FatRawBufferCastOp::verify() {
return success();
}
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return true;
+ else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+ else if (auto gpuMemorySpace =
+ llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+ if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ return intMemorySpace.getInt() == 3;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
- Attribute memorySpace = bufferType.getMemorySpace();
- bool isGlobal = false;
- if (!memorySpace)
- isGlobal = true;
- else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
- isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
- else if (auto gpuMemorySpace =
- llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
- isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
if (!isGlobal)
return op.emitOpError(
@@ -473,13 +483,22 @@ LogicalResult GatherToLDSOp::verify() {
if (elemType != dstType.getElementType())
return emitOpError("source and destination element types must match");
- // Element type sizes should be 1, 2, or 4 bytes.
- auto elemSize = elemType.getIntOrFloatBitWidth();
- if (elemSize != 8 && elemSize != 16 && elemSize != 32)
- return emitOpError("source and destination element types must be 8, 16, "
- "or 32 bits");
+  // Copy type sizes should be 1, 2, or 4 bytes.
+ auto transferType = getTransferType();
+ size_t transferSize;
+ if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+ transferSize = vectorTransfer.getNumElements() *
+ vectorTransfer.getElementTypeBitWidth();
+ } else {
+ transferSize = transferType.getIntOrFloatBitWidth();
+ }
+ if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("transfer type size must be 8, 16, or 32 bits");
+
+ if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+ return emitOpError("source memory address space must be Global");
- if (!gpu::GPUDialect::hasWorkgroupMemoryAddressSpace(dstType))
+ if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
return emitOpError("destination memory address space must be Workgroup");
return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 65515dbe30b29..5dad6b75f9ec8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -466,21 +466,21 @@ func.func @sched_barrier() {
// CHECK-LABEL: func @global_load_to_rocdl_f32
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
- %c0 = arith.constant 0 : i32
- %c12 = arith.constant 12 : i32
- %c32 = arith.constant 32 : i32
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
// GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
// GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
// GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
- // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
- // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+ // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+ // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
// GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
- amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
func.return
}
>From 81bb8bc2aa6b75aabe6a0d50436cabca90287691 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 14:37:42 -0400
Subject: [PATCH 08/11] update test files.
---
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 23 ----------------
.../Conversion/AMDGPUToROCDL/load_lds.mlir | 26 +++++++++++++++++++
2 files changed, 26 insertions(+), 23 deletions(-)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 5dad6b75f9ec8..8871b2ce0eadb 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -9,7 +9,6 @@
// test pass doesn't set up the GPU address space conversions.
#gpu_global_addrspace = 1
-#gpu_lds_addrspace = 3
// CHECK-LABEL: func @fat_raw_buffer_cast
func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -462,25 +461,3 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}
-
-// CHECK-LABEL: func @global_load_to_rocdl_f32
-// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
-func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
- %c0 = arith.constant 0 : index
- %c12 = arith.constant 12 : index
- %c32 = arith.constant 32 : index
- %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
- // GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
- // GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
- // GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
- // GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
- // GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
- // GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
- // GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
- // GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
- // GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
- // GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
- amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
- func.return
-}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..507928eaf0d78
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}
>From 73629f4f09c581e027943804270c582621bc7192 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 16:09:58 -0400
Subject: [PATCH 09/11] linting
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++-------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
2 files changed, 5 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a644c24a8d6bc..5e8cbf262f0eb 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -903,8 +903,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
-struct GatherToLDSOpLowering
- : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -936,15 +935,13 @@ struct GatherToLDSOpLowering
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("chipset unsupported element size");
- auto convertIndices =
- [&](ValueRange indices) -> SmallVector<Value, 4> {
+ auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
SmallVector<Value, 4> convertedIndices;
-
+
for (Value index : indices) {
Type convertedType = getTypeConverter()->convertType(index.getType());
auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
- loc, convertedType,
- rewriter.getIntegerAttr(convertedType, 0));
+ loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
convertedIndices.push_back(convertedIndex);
}
return convertedIndices;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 605f0a92224e9..00d67f58aee1c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -120,7 +120,7 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
else if (auto gpuMemorySpace =
- llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
return false;
}
>From b48370161416b7ddce8734664162d282ca4c2bcb Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:52:19 -0400
Subject: [PATCH 10/11] Update tests.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 22 +--
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 7 +-
.../Conversion/AMDGPUToROCDL/load_lds.mlir | 137 +++++++++++++++++-
3 files changed, 139 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5e8cbf262f0eb..7efe70fd16a04 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,24 +935,10 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("chipset unsupported element size");
- auto convertIndices = [&](ValueRange indices) -> SmallVector<Value, 4> {
- SmallVector<Value, 4> convertedIndices;
-
- for (Value index : indices) {
- Type convertedType = getTypeConverter()->convertType(index.getType());
- auto convertedIndex = rewriter.create<LLVM::ConstantOp>(
- loc, convertedType, rewriter.getIntegerAttr(convertedType, 0));
- convertedIndices.push_back(convertedIndex);
- }
- return convertedIndices;
- };
-
- Value srcPtr =
- getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
- convertIndices(op.getSrcIndices()), rewriter);
- Value dstPtr =
- getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
- convertIndices(op.getDstIndices()), rewriter);
+ Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+ (adaptor.getSrcIndices()), rewriter);
+ Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+ (adaptor.getDstIndices()), rewriter);
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 00d67f58aee1c..6309dac761e4f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -117,10 +117,9 @@ LogicalResult FatRawBufferCastOp::verify() {
static bool hasGlobalMemorySpace(Attribute memorySpace) {
if (!memorySpace)
return true;
- else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
- else if (auto gpuMemorySpace =
- llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
return false;
}
@@ -128,6 +127,8 @@ static bool hasGlobalMemorySpace(Attribute memorySpace) {
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
return intMemorySpace.getInt() == 3;
+ if (auto gpuMemorySpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
return false;
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index 507928eaf0d78..c887f6504172a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -10,17 +10,142 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
- // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
- // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
+ : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]], %[[C0]], %[[C0_2]]
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
+ : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+ func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
- // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]]
+
+ // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+ // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
- // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]]
+
+ // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+ // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+ // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
- amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32} : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c32 = arith.constant 32 : index
+ %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
+ : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
func.return
}
+
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+ // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+ // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+ // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK: %[[ALLOC:.*]] = memref.alloc()
+ // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+ // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+ // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+ // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+ // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
+ %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+ %c0 = arith.constant 0 : index
+ amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
+ : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+ func.return
+}
\ No newline at end of file
>From 1a40d6c20698052a569c8f1a49425ed6acb3834c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 2 Apr 2025 20:59:25 -0400
Subject: [PATCH 11/11] update again
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir | 16 ++++++++--------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 350b184dae6a2..52753e2faf901 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -795,7 +795,7 @@ def AMDGPU_GatherToLDSOp :
Note: only enabled for gfx942 and later.
}];
let assemblyFormat = [{
- $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)
+ $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
index c887f6504172a..160e5b203ed1f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -39,8 +39,8 @@ func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_add
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
- amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = f32}
- : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
func.return
}
@@ -80,8 +80,8 @@ func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrs
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
- amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = i8}
- : memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
func.return
}
@@ -121,8 +121,8 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
- amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] {transferType = vector<2 x i16>}
- : memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+ : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
func.return
}
@@ -145,7 +145,7 @@ func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_g
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
- amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] {transferType = i32}
- : memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+ : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
func.return
}
\ No newline at end of file
More information about the Mlir-commits
mailing list