[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Feb 20 10:10:47 PST 2025
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/125594
From ceed9d2674bcf0c5893502763ab493bec6bf2efb Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 3 Feb 2025 22:49:38 +0000
Subject: [PATCH] [mlir][AMDGPU] Plumb address space 7 through MLIR, add
address_space attr.
This commit adds support to the AMDGPU dialect for casting memrefs
into fat raw buffer pointers.
Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided by some API) in a pointer that supports ordinary
pointer operations like load or store. This lets people take advantage
of the additional semantics that buffer_load and similar instructions
provide without forcing the use of entirely separate
amdgpu.raw_buffer_* operations.
Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.
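
For illustration, here is a minimal sketch of the intended usage once
this patch lands (the function and value names are invented for this
example; the buffer descriptor details are filled in by
-convert-amdgpu-to-rocdl):

    func.func @load_through_fat_pointer(%src: memref<8xi32>, %i: index) -> i32 {
      // Wrap the memref as a buffer fat pointer; num_records, flags, etc.
      // are computed during lowering to ROCDL.
      %buf = amdgpu.fat_raw_buffer_cast %src
          : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
      // Ordinary memref operations on the result lower to buffer intrinsics.
      %v = memref.load %buf[%i] : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
      return %v : i32
    }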
This commit also defines a #amdgpu.address_space<> attribute so that
AMDGPU-specific memory spaces can be represented. Only
#amdgpu.address_space<fat_raw_buffer> will work correctly with the
memref dialect, but the other possible address spaces are included
for completeness.
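
As a further sketch, adapted from the tests added in this patch, the
optional operands and flags of the new cast appear in its custom syntax
as follows (the function name is illustrative):

    func.func @cast_with_options(%m: memref<8xi32, strided<[1], offset: ?>>,
                                 %validBytes: i32, %cacheSwizzle: i14)
        -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
      // resetOffset folds the dynamic offset into the base pointer, so the
      // result type carries a zero offset.
      %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes)
          cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
          : memref<8xi32, strided<[1], offset: ?>>
            to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
      func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
    }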
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h | 16 +-
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 105 +++++++
.../mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h | 2 +
.../mlir/Dialect/AMDGPU/Transforms/Passes.h | 3 +
.../mlir/Dialect/AMDGPU/Transforms/Passes.td | 16 ++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 258 +++++++++++++-----
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 53 ++++
.../Dialect/AMDGPU/Transforms/CMakeLists.txt | 2 +
.../Transforms/ResolveStridedMetadata.cpp | 79 ++++++
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 122 ++++++++-
.../amdgpu-resolve-strided-metadata.mlir | 51 ++++
mlir/test/Dialect/AMDGPU/invalid.mlir | 25 ++
mlir/test/Dialect/AMDGPU/ops.mlir | 19 ++
13 files changed, 680 insertions(+), 71 deletions(-)
create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
create mode 100644 mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index cc32e97084830..b550980c4ad01 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -16,18 +16,26 @@ namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
+class TypeConverter;
class Pass;
#define GEN_PASS_DECL_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+ TypeConverter &typeConverter);
+
} // namespace mlir
#endif // MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f795dd89b79a1..6998e7b8ada5b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -9,8 +9,11 @@
#ifndef AMDGPU
#define AMDGPU
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
include "mlir/IR/OpBase.td"
def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+ "AMDGPU-specific address spaces",
+ [
+ I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">,
+ I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">,
+ I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+ "address_space"> {
+ let description = [{
+ AMDGPU-specific memory spaces that may not have exact analogues on other
+ GPU targets or backends.
+
+  - `fat_raw_buffer` is the memory space used when a memref is stored as
+    a "buffer fat pointer" - that is, a buffer resource (that is set up to
+ use raw byte-level indexing) along with its offset. The AMDGPU backend
+ implements `ptr addrspace(7)` to represent these fat pointers so that
+ buffer resources (which allow advanced features like bounds checking or
+ cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+    See also the `fat_raw_buffer_cast` operation.
+  - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a
+    buffer resource. It should not be used for memrefs, since it does not
+    support indexing.
+ - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource
+ that carries both an index and offset field, which are used for complex
+ structured indexing that is primarily seen in graphics applications. This
+ is also incompatible with the simple indexing model supported by memref.
+ }];
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// AMDGPU Op definitions
//===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
let hasVerifier = 1;
}
+def AMDGPU_FatRawBufferCastOp :
+ AMDGPU_Op<"fat_raw_buffer_cast",
+ [Pure,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ ViewLikeOpInterface, AttrSizedOperandSegments]>,
+ Arguments<(ins AnyMemRef:$source,
+ Optional<I32>:$validBytes,
+ Optional<I<14>>:$cacheSwizzleStride,
+ DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+ UnitProp:$resetOffset)>,
+ Results<(outs AnyMemRef:$result)> {
+ let summary = "Create a raw buffer fat pointer that matches `memref`";
+ let description = [{
+ Wraps the memory pointed to by `source` as a raw buffer fat pointer, or,
+ in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same
+ sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+ address space.
+
+ This memref can be used with standard memref operations like `memref.load`,
+ `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant
+ buffer intrinsics. (`vector.masked_load/store` will work once there's backend
+ support for lowering them, and then this document will be updated)
+
+    If `validBytes` is given, it is the number of bytes that will be valid as
+    an offset into the resulting buffer. If it is not provided, it will be
+    inferred from the size of the memref during lowering. This size is
+    max_d (sizes[d] * strides[d]) * sizeof(element type).
+
+ The flags of the buffer descriptor will be set up to enable raw usage -
+ for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+ property determines if bounds checking is enabled or not (on architectures
+ where this can be controlled - that is, on RDNA chips).
+
+ If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+ on architectures that support it. This swizzling, unlike the main swizzling
+    mode (whose usage makes a buffer non-raw), does not affect index
+    calculation, but does affect cache behavior. Mixing access between
+    cache-swizzled raw buffers and other forms of memory access, like ordinary
+    pointer loads or unswizzled buffer pointers, can cause incorrect behavior
+    and must be avoided.
+
+    This operation preserves the sizes, strides, and offset of the input
+    memref - they'll be added in by `memref.load` later. However, if
+    `resetOffset` is set, that offset is instead folded into the base pointer
+    and the result's offset is reset to 0.
+ If the value of the memref's offset is not uniform (independent of the lane/thread ID),
+ this will lead to substantially decreased performance due to the need for
+ a waterfall loop on the base address of the buffer resource.
+ }];
+
+ let extraClassDeclaration = [{
+ Value getViewSource() { return getSource(); }
+ }];
+
+ let assemblyFormat = [{
+ $source oilist (`validBytes` `(` $validBytes `)`
+ | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+ | `boundsCheck` `(` $boundsCheck `)`
+ | `resetOffset` $resetOffset )
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
+
/// Raw buffer load
def AMDGPU_RawBufferLoadOp :
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 0a2e6bb5e9fe4..3de57c923178a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -18,7 +18,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198..c3ae7930e8ec8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,12 +21,15 @@ class ConversionTarget;
namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
RewritePatternSet &patterns,
Chipset chipset);
+
+void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
} // namespace amdgpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 23f8b8f653b67..ef50367b67a21 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -31,4 +31,20 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
"Chipset that these operations will run on">];
}
+def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
+ let summary = "Resolve memref.extract_strided_metadata on AMDGPU ops";
+ let description = [{
+    This pass rewrites `memref.extract_strided_metadata` operations
+    whose source is a cast from the AMDGPU dialect, such as
+    `amdgpu.fat_raw_buffer_cast`.
+
+ It's mainly meant for testing - please incorporate the patterns into your
+ own extract-strided-metadata passes (or run memref's expand-strided-metadata
+ again after this).
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "memref::MemRefDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b29228ef87ea7..357e6bd0bca98 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include "../LLVMCommon/MemRefDescriptor.h"
+
#include "llvm/ADT/STLExtras.h"
#include <optional>
@@ -30,6 +32,11 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+// Define commonly used chipsets versions for convenience.
+constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
Location loc, Value val) {
@@ -76,11 +83,166 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
return index ? index : createI32Constant(rewriter, loc, 0);
}
+/// Compute the contents of the `num_records` field for a given memref
+/// descriptor - that is, the size in bytes that is one element past the
+/// greatest possible valid index into the memref.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+ MemRefType memrefType,
+ MemRefDescriptor &memrefDescriptor,
+ ArrayRef<int64_t> strides,
+ uint32_t elementByteWidth) {
+ if (memrefType.hasStaticShape() &&
+ !llvm::any_of(strides, ShapedType::isDynamic)) {
+ int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+ size = std::max(shape[i] * strides[i], size);
+ size = size * elementByteWidth;
+ assert(size < std::numeric_limits<uint32_t>::max() &&
+ "the memref buffer is too large");
+ return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+ }
+ Value maxIndex;
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+ Value size = memrefDescriptor.size(rewriter, loc, i);
+ Value stride = memrefDescriptor.stride(rewriter, loc, i);
+ Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ maxIndex = maxIndex
+ ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ : maxThisDim;
+ }
+ Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
+ Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
+ return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+ Value basePointer, Value numRecords,
+ bool boundsCheck, amdgpu::Chipset chipset,
+ Value cacheSwizzleStride = nullptr,
+ unsigned addressSpace = 8) {
+ // The stride value is generally 0. However, on MI-300 and onward, you can
+ // enable a cache swizzling mode by setting bit 14 of the stride field
+ // and setting that stride to a cache stride.
+ Type i16 = rewriter.getI16Type();
+ Value stride;
+  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
+ Value cacheStrideZext =
+ rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+ Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+ loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
+ } else {
+ stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+ rewriter.getI16IntegerAttr(0));
+ }
+ // Flag word:
+ // bits 0-11: dst sel, ignored by these intrinsics
+ // bits 12-14: data format (ignored, must be nonzero, 7=float)
+ // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+ // bit 19: In nested heap (0 here)
+ // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+ // bits 21-22: Index stride for swizzles (N/A)
+ // bit 23: Add thread ID (0)
+ // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+ // bits 25-26: Reserved (0)
+ // bit 27: Buffer is non-volatile (CDNA only)
+ // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+ // none, 3 = either swizzles or testing against offset field) RDNA only
+ // bits 30-31: Type (must be 0)
+ uint32_t flags = (7 << 12) | (4 << 15);
+ if (chipset.majorVersion >= 10) {
+ flags |= (1 << 24);
+ uint32_t oob = boundsCheck ? 3 : 2;
+ flags |= (oob << 28);
+ }
+ Value flagsConst = createI32Constant(rewriter, loc, flags);
+ Type rsrcType =
+ LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
+ Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+ loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+ return resource;
+}
+
namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+struct FatRawBufferCastLowering
+ : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+ FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+ chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value memRef = adaptor.getSource();
+ Value unconvertedMemref = op.getSource();
+ MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+ MemRefDescriptor descriptor(memRef);
+
+ DataLayout dataLayout = DataLayout::closest(op);
+ int64_t elementByteWidth =
+ dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+ int64_t unusedOffset = 0;
+ SmallVector<int64_t, 5> strideVals;
+ if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+ return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+ Value numRecords = adaptor.getValidBytes();
+ if (!numRecords)
+ numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+ strideVals, elementByteWidth);
+
+ Value basePointer =
+ adaptor.getResetOffset()
+ ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType)
+ : descriptor.alignedPtr(rewriter, loc);
+
+ Value offset = adaptor.getResetOffset()
+ ? rewriter.create<LLVM::ConstantOp>(
+ loc, getIndexType(), rewriter.getIndexAttr(0))
+ : descriptor.offset(rewriter, loc);
+
+ bool hasSizes = memrefType.getRank() > 0;
+ // No need to unpack() and pack() all the individual sizes and strides,
+ // so we'll just extract the arrays.
+ Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides = hasSizes
+ ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kStridePosInMemRefDescriptor)
+ : Value{};
+
+ Value fatPtr = makeBufferRsrc(
+ rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
+ chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
+
+ Value result = MemRefDescriptor::poison(
+ rewriter, loc,
+ getTypeConverter()->convertType(op.getResult().getType()));
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
+ if (hasSizes) {
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, strides, kStridePosInMemRefDescriptor);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
@@ -122,7 +284,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
Type i32 = rewriter.getI32Type();
- Type i16 = rewriter.getI16Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -199,60 +360,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Value ptr = memrefDescriptor.bufferPtr(
rewriter, loc, *this->getTypeConverter(), memrefType);
- // The stride value is always 0 for raw buffers. This also disables
- // swizling.
- Value stride = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(0));
- // Get the number of elements.
- Value numRecords;
- if (memrefType.hasStaticShape() &&
- !llvm::any_of(strides, ShapedType::isDynamic)) {
- int64_t size = memrefType.getRank() == 0 ? 1 : 0;
- ArrayRef<int64_t> shape = memrefType.getShape();
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
- size = std::max(shape[i] * strides[i], size);
- size = size * elementByteWidth;
- assert(size < std::numeric_limits<uint32_t>::max() &&
- "the memref buffer is too large");
- numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
- } else {
- Value maxIndex;
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
- Value size = memrefDescriptor.size(rewriter, loc, i);
- Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
- maxIndex =
- maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
- : maxThisDim;
- }
- numRecords = rewriter.create<LLVM::MulOp>(
- loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
- }
-
- // Flag word:
- // bits 0-11: dst sel, ignored by these intrinsics
- // bits 12-14: data format (ignored, must be nonzero, 7=float)
- // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
- // bit 19: In nested heap (0 here)
- // bit 20: Behavior on unmap (0 means "return 0 / ignore")
- // bits 21-22: Index stride for swizzles (N/A)
- // bit 23: Add thread ID (0)
- // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
- // bits 25-26: Reserved (0)
- // bit 27: Buffer is non-volatile (CDNA only)
- // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
- // none, 3 = either swizzles or testing against offset field) RDNA only
- // bits 30-31: Type (must be 0)
- uint32_t flags = (7 << 12) | (4 << 15);
- if (chipset.majorVersion >= 10) {
- flags |= (1 << 24);
- uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
- flags |= (oob << 28);
- }
- Value flagsConst = createI32Constant(rewriter, loc, flags);
- Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
- Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
- loc, rsrcType, ptr, stride, numRecords, flagsConst);
+ Value numRecords = getNumRecords(
+ rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
+ Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
+ adaptor.getBoundsCheck(), chipset);
args.push_back(resource);
// Indexing (voffset)
@@ -1062,11 +1173,32 @@ struct ConvertAMDGPUToROCDLPass
};
} // namespace
-void mlir::populateAMDGPUToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- Chipset chipset) {
+void mlir::populateAMDGPUMemorySpaceAttributeConversions(
+ TypeConverter &typeConverter) {
+ typeConverter.addTypeAttributeConversion(
+ [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
+ -> TypeConverter::AttributeConversionResult {
+ MLIRContext *ctx = as.getContext();
+ Type i64 = IntegerType::get(ctx, 64);
+ switch (as.getValue()) {
+ case amdgpu::AddressSpace::FatRawBuffer:
+ return IntegerAttr::get(i64, 7);
+ case amdgpu::AddressSpace::BufferRsrc:
+ return IntegerAttr::get(i64, 8);
+ case amdgpu::AddressSpace::FatStructuredBuffer:
+ return IntegerAttr::get(i64, 9);
+ }
+ return TypeConverter::AttributeConversionResult::abort();
+ });
+}
+
+void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ Chipset chipset) {
+ populateAMDGPUMemorySpaceAttributeConversions(converter);
patterns
- .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+ .add<FatRawBufferCastLowering,
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
RawBufferOpLowering<RawBufferAtomicFaddOp,
ROCDL::RawPtrBufferAtomicFaddOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..d2bfb863244d9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// FatRawBufferCastOp
+//===----------------------------------------------------------------------===//
+
+/// Convert the type `source` to one with the same sizes and strides - and
+/// the same offset, unless `resetOffset` is true, in which case the offset is
+/// reset to 0. If the offset should be reset but the layout of `source` is
+/// neither the identity layout nor a strided layout, this function fails.
+static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
+ bool resetOffset) {
+ MLIRContext *ctx = source.getContext();
+ MemRefType::Builder mb(source);
+ mb.setMemorySpace(
+ amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
+ MemRefLayoutAttrInterface layout = source.getLayout();
+ if (resetOffset && !layout.isIdentity()) {
+ auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
+ if (!stridedLayout)
+ return failure();
+ mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+ }
+ return (MemRefType)(mb);
+}
+
+LogicalResult FatRawBufferCastOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ Adaptor adaptor(operands, attributes, properties, regions);
+ auto sourceType =
+ dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
+ if (!sourceType)
+ return failure();
+ FailureOr<MemRefType> resultType =
+ getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
+ if (failed(resultType))
+ return failure();
+ inferredReturnTypes = SmallVector<Type>{*resultType};
+ return success();
+}
+
+LogicalResult FatRawBufferCastOp::verify() {
+ FailureOr<MemRefType> expectedResultType =
+ getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
+ if (failed(expectedResultType))
+ return emitOpError("source type ")
+ << getSource().getType() << " can't have its offset reset";
+ if (getResult().getType() != *expectedResultType)
+ return emitOpError("expected result type to be ")
+ << *expectedResultType << " but got " << getResult().getType();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 5f934714d988a..3d4567bff1e32 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
+ ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
@@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRArithDialect
+ MLIRMemRefDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
new file mode 100644
index 0000000000000..4b3d94b4ce2ad
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -0,0 +1,79 @@
+//===- ResolveStridedMetadata.cpp - AMDGPU expand_strided_metadata ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPURESOLVESTRIDEDMETADATAPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+namespace {
+struct AmdgpuResolveStridedMetadataPass
+ : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase<
+ AmdgpuResolveStridedMetadataPass> {
+ void runOnOperation() override;
+};
+
+struct ExtractStridedMetadataOnFatRawBufferCastFolder final
+ : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
+ PatternRewriter &rewriter) const override {
+ auto castOp = metadataOp.getSource().getDefiningOp<FatRawBufferCastOp>();
+ if (!castOp)
+ return rewriter.notifyMatchFailure(metadataOp,
+ "not a fat raw buffer cast");
+ Location loc = castOp.getLoc();
+ auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, castOp.getSource());
+ SmallVector<Value> results;
+ if (metadataOp.getBaseBuffer().use_empty()) {
+ results.push_back(nullptr);
+ } else {
+ auto baseBufferType =
+ cast<MemRefType>(metadataOp.getBaseBuffer().getType());
+ if (baseBufferType == castOp.getResult().getType()) {
+ results.push_back(castOp.getResult());
+ } else {
+ results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+ loc, baseBufferType, castOp.getResult(), /*offset=*/0,
+ /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
+ }
+ }
+ if (castOp.getResetOffset())
+ results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ else
+ results.push_back(sourceMetadata.getOffset());
+ llvm::append_range(results, sourceMetadata.getSizes());
+ llvm::append_range(results, sourceMetadata.getStrides());
+ rewriter.replaceOp(metadataOp, results);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ExtractStridedMetadataOnFatRawBufferCastFolder>(
+ patterns.getContext());
+}
+
+void AmdgpuResolveStridedMetadataPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateAmdgpuResolveStridedMetadataPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 062b63c076c3c..ae1b34ef3f8eb 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -1,13 +1,124 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX908
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx90a | FileCheck %s --check-prefixes=CHECK,GFX9,GFX90A
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX942
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10,RDNA
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11,RDNA
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12,RDNA
+// Note: #gpu.address_space<global> is hardcoded to `1` here because the
+// test pass doesn't set up the GPU address space conversions.
+
+#gpu_global_addrspace = 1
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<8xi32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3]
+ // CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4]
+ // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+ // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]] : <1> to <7>
+ // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+ // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+ // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+ // CHECK: %[[ret4:.*]] = llvm.insertvalue %[[sizes]], %[[ret3]][3]
+ // CHECK: %[[ret5:.*]] = llvm.insertvalue %[[strides]], %[[ret4]][4]
+ // CHECK: builtin.unrealized_conversion_cast %[[ret5]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_0d
+func.func @fat_raw_buffer_cast_0d(%buf: memref<i32, #gpu_global_addrspace>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<i32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64)>
+ // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+ // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+ // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64)>
+ // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+ // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+ // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+ // CHECK: builtin.unrealized_conversion_cast %[[ret3]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<i32, #gpu_global_addrspace> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
+func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
+ // CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0]
+ // CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]]
+ // CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32
+ // CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]]
+ // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2]
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+ // CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[basePtr:.*]] = llvm.getelementptr %[[memRefPtr]][%[[memRefOff]]]
+ // CHECK-DAG: %[[zeroOff:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
+ // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
+ // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
+ %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
+func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[numRecords:.*]] = arith.constant -1 : i32
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+ %cu32_max = arith.constant 0xffffffff : i32
+ %ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_bounds_check
+func.func @fat_raw_buffer_cast_bounds_check(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(553807872 : i32)
+ // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf boundsCheck(false) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_cache_swizzle
+// CHECK-SAME: (%{{.*}}: memref<64x64xi32, 1>, %[[stride:.*]]: i14)
+func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global_addrspace>, %stride: i14) -> memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // GFX908: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX90A: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // RDNA: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX942: %[[asI16:.*]] = llvm.zext %[[stride]] : i14 to i16
+ // GFX942: %[[cacheSwizzleOn:.*]] = llvm.mlir.constant(16384 : i16) : i16
+ // GFX942: %[[stride:.*]] = llvm.or disjoint %[[asI16]], %[[cacheSwizzleOn]]
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %{{.*}}, %{{.*}}
+ %ret = amdgpu.fat_raw_buffer_cast %buf cacheSwizzleStride(%stride) : memref<64x64xi32, #gpu_global_addrspace> to memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32
func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
+ // Extra constant for byte width
+ // CHECK: llvm.mlir.constant(4 : i32)
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32)
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -19,8 +130,8 @@ func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -37,7 +148,6 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
// CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
// CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
@@ -46,7 +156,9 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
// CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
// CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
// CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
- // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32
+ // CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
// CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
@@ -289,6 +401,8 @@ func.func @lds_barrier() {
// GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
// GFX90A: rocdl.s.waitcnt -7937
// GFX90A-NEXT: rocdl.s.barrier
+ // GFX942: rocdl.s.waitcnt -7937
+ // GFX942-NEXT: rocdl.s.barrier
// GFX10: rocdl.s.waitcnt -16129
// GFX10-NEXT: rocdl.s.barrier
// GFX11: llvm.inline_asm has_side_effects asm_dialect = att
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
new file mode 100644
index 0000000000000..831bb5f0f66ec
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt -amdgpu-resolve-strided-metadata -split-input-file %s | FileCheck %s
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_no_offset_reset
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %[[offset:.+]], %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]]
+// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_no_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index
+}
+
+// -----
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?]>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_offset_reset
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]]
+// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index
+}
+
+// -----
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?]>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_no_base_ptr
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: return %[[cast]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_no_base_ptr(%arg0: !tSrc) -> (!tDst, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %cast, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tDst, index, index, index, index, index
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5e1ab79962d2f..7cb16f5259070 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -125,3 +125,28 @@ func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32>
%0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
func.return %0 : vector<8xi32>
}
+
+// -----
+
// Missing `resetOffset`
+func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// -----
+
+func.func @fat_raw_buffer_cast_wrong_as(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<buffer_rsrc>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<buffer_rsrc>>'}}
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<buffer_rsrc>>
+ return %ret : memref<8xi32, #amdgpu.address_space<buffer_rsrc>>
+}
+
+// -----
+
+func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op source type 'memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>' can't have its offset reset}}
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset : memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 9457a1b9e4498..567e6498330a3 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -25,6 +25,25 @@ func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M
func.return %ret : vector<4xf8E5M2FNUZ>
}
+// CHECK-LABEL: func @fat_raw_buffer_cast_easy
+// CHECK: amdgpu.fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+// CHECK: amdgpu.fat_raw_buffer_cast
+// CHECK-SAME: validBytes(%{{[^)]*}})
+// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
+// CHECK-SAME: boundsCheck(false)
+// CHECK-SAME: resetOffset
+func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
+ : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
// CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32