[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Feb 11 12:07:57 PST 2025
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/125594
From dd4b1101d255deceb9036aaf3a5a3efd2d9a5745 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 3 Feb 2025 22:49:38 +0000
Subject: [PATCH 1/8] [mlir][AMDGPU] Plumb address space 7 through MLIR, add
address_space attr.
This commit adds support to the AMDGPU dialect for casting memrefs
into fat raw buffer pointers.
Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided from some API) into a pointer that supports
ordinary pointer operations like load or store. This allows people to
take advantage of the additional semantics that buffer_load and
similar instructions provide without forcing the use of entirely
separate amdgpu.raw_buffer_* operations.
Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.
This commit also defines a #amdgpu.address_space<> attribute so that
AMDGPU-specific memory spaces can be represented.
Only #amdgpu.address_space<fat_raw_buffer> will work correctly with
the memref dialect, but the other possible address spaces are included
for completeness.
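
For readers skimming the diff: a minimal usage sketch, mirroring the tests
added below (the function name and shapes are illustrative only):

  func.func @example(%buf: memref<8xi32, #gpu.address_space<global>>,
                     %idx: index) -> i32 {
    %fat = amdgpu.fat_raw_buffer_cast %buf
      : memref<8xi32, #gpu.address_space<global>>
        to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
    // Ordinary memref operations on the result lower to buffer intrinsics.
    %v = memref.load %fat[%idx]
      : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
    func.return %v : i32
  }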
---
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h | 16 +-
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 105 +++++++
.../mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h | 2 +
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 256 +++++++++++++-----
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 53 ++++
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 105 ++++++-
6 files changed, 466 insertions(+), 71 deletions(-)
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index e7637a6013e68..bb4e7bc037a37 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -16,18 +16,26 @@ namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
+class TypeConverter;
class Pass;
#define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+ TypeConverter &typeConverter);
+
std::unique_ptr<Pass> createConvertAMDGPUToROCDLPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 69745addfd748..6c42849fc71f1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -9,8 +9,11 @@
#ifndef AMDGPU
#define AMDGPU
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
include "mlir/IR/OpBase.td"
def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+ "AMDGPU-specific address spaces",
+ [
+ I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">,
+ I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">,
+ I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+ "address_space"> {
+ let description = [{
+ AMDGPU-specific memory spaces that may not have exact analogues on other
+ GPU targets or backends.
+
+ - fat_raw_buffer is the memory space used when a memref is stored
+ as a "buffer fat pointer" - that is, a buffer resource (that is set up to
+ use raw byte-level indexing) along with its offset. The AMDGPU backend
+ implements ptr addrspace(7) to represent these fat pointers so that
+ buffer resources (which allow advanced features like bounds checking or
+ cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+ See also the fat_raw_buffer_cast operation
+ - buffer_rsrc is the memory space for ptr addrspace(8), representing a
+ buffer resource. It should not be used for memrefs, since it does not support
+ indexing
+ - fat_structured_buffer represents ptr addrspace(9), a buffer resource
+ that carries both an index and offset field, which are used for complex
+ structured indexing that is primarily seen in graphics applications. This
+ is also incompatible with the simple indexing model supported by memref.
+ }];
+ let assemblyFormat = [{ `<` $value `>` }];
+}
+
//===----------------------------------------------------------------------===//
// AMDGPU Op definitions
//===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
let hasVerifier = 1;
}
+def AMDGPU_FatRawBufferCastOp :
+ AMDGPU_Op<"fat_raw_buffer_cast",
+ [Pure,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ ViewLikeOpInterface, AttrSizedOperandSegments]>,
+ Arguments<(ins AnyMemRef:$source,
+ Optional<I32>:$validBytes,
+ Optional<I<14>>:$cacheSwizzleStride,
+ DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+ UnitProp:$resetOffset)>,
+ Results<(outs AnyMemRef:$result)> {
+ let summary = "Create a raw buffer fat pointer that matches `memref`";
+ let description = [{
+ Wraps the memory pointed to by `in` as a raw buffer fat pointer, or,
+ in LLVM terms, a ptr addrspace(7), returning a memref that has the same
+ sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+ address space.
+
+ This memref can be used with standard memref operations like `memref.load`,
+ `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant
+ buffer intrinsics. (`vector.masked_load/store` will work once there's backend
+ support for lowering them, and then this document will be updated)
+
+ If `validBytes` is given, it is the number of bytes that will be valid as
+ an offset to `out`. If it is not provided, this will be inferred from
+ the size of the memref during lowering. This size is
+ max_d (sizes[d] * strides[d]) * sizeof(element type).
+
+ The flags of the buffer descriptor will be set up to enable raw usage -
+ for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+ property determines if bounds checking is enabled or not (on architectures
+ where this can be controlled - that is, on RDNA chips).
+
+ If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+ on architectures that support it. This swizzling, unlike the main swizzling
+ mode (whose usage makes a buffer non-raw) does not affect index calculaton,
+ but does affect cache behavior. Mixing access between cache-swizzled raw
+ buffers and other forms of memory access, like ordinary pointer loads or
+ unswizzled buffer pointers can cause incorrect behavior and must be avoided.
+
+ This operation preserves the sizes, strides, and offset of the input
+ memref - they'll be added in by `memref.load` later. However, if
+ `resetOffset` is set, that offset will be added to the base pointer.
+ If the value of the memref's offset is not independent of the lane/thread ID,
+ this will lead to substantially decreased performance due to the need for
+ a waterfall loop on the base address of the buffer resource.
+ }];
+
+ let extraClassDeclaration = [{
+ Value getViewSource() { return getSource(); }
+ }];
+
+ let assemblyFormat = [{
+ $source oilist (`validBytes` `(` $validBytes `)`
+ | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+ | `boundsCheck` `(` $boundsCheck `)`
+ | `resetOffset` $resetOffset )
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
+
/// Raw buffer load
def AMDGPU_RawBufferLoadOp :
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 0a2e6bb5e9fe4..3de57c923178a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -18,7 +18,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9fb51f0bc1f1e..173b0b612ca0d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include "../LLVMCommon/MemRefDescriptor.h"
+
#include "llvm/ADT/STLExtras.h"
#include <optional>
@@ -30,6 +32,11 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+// Define commonly used chipsets versions for convenience.
+static constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+static constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+static constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
Location loc, Value val) {
@@ -76,11 +83,164 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
return index ? index : createI32Constant(rewriter, loc, 0);
}
+/// Compute the contents of the `num_records` field for a given memref
+/// descriptor - that is, the number of bytes that's one element past the
+/// greatest possible valid index into the memref.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+ MemRefType memrefType,
+ MemRefDescriptor &memrefDescriptor,
+ ArrayRef<int64_t> strides,
+ uint32_t elementByteWidth) {
+ if (memrefType.hasStaticShape() &&
+ !llvm::any_of(strides, ShapedType::isDynamic)) {
+ int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+ size = std::max(shape[i] * strides[i], size);
+ size = size * elementByteWidth;
+ assert(size < std::numeric_limits<uint32_t>::max() &&
+ "the memref buffer is too large");
+ return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+ }
+ Value maxIndex;
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+ Value size = memrefDescriptor.size(rewriter, loc, i);
+ Value stride = memrefDescriptor.stride(rewriter, loc, i);
+ Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ maxIndex = maxIndex
+ ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ : maxThisDim;
+ }
+ return rewriter.create<LLVM::MulOp>(
+ loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+ createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+ Value basePointer, Value numRecords,
+ bool boundsCheck, amdgpu::Chipset chipset,
+ Value cacheSwizzleStride = nullptr) {
+ // The stride value is generally 0. However, on MI-300 and onward, you can
+ // enable a cache swizzling mode by setting bit 14 of the stride field
+ // and setting that stride to a cache stride.
+ Type i16 = rewriter.getI16Type();
+ Value stride;
+ if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+ Value cacheStrideZext =
+ rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+ Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+ loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
+ } else {
+ stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+ rewriter.getI16IntegerAttr(0));
+ }
+ // Get the number of elements.
+ // Flag word:
+ // bits 0-11: dst sel, ignored by these intrinsics
+ // bits 12-14: data format (ignored, must be nonzero, 7=float)
+ // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+ // bit 19: In nested heap (0 here)
+ // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+ // bits 21-22: Index stride for swizzles (N/A)
+ // bit 23: Add thread ID (0)
+ // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+ // bits 25-26: Reserved (0)
+ // bit 27: Buffer is non-volatile (CDNA only)
+ // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+ // none, 3 = either swizzles or testing against offset field) RDNA only
+ // bits 30-31: Type (must be 0)
+ uint32_t flags = (7 << 12) | (4 << 15);
+ if (chipset.majorVersion >= 10) {
+ flags |= (1 << 24);
+ uint32_t oob = boundsCheck ? 3 : 2;
+ flags |= (oob << 28);
+ }
+ Value flagsConst = createI32Constant(rewriter, loc, flags);
+ Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+ Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+ loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+ return resource;
+}
+
namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+ : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+ FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+ chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value memRef = adaptor.getSource();
+ Value unconvertedMemref = op.getSource();
+ MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+ MemRefDescriptor descriptor(memRef);
+
+ DataLayout dataLayout = DataLayout::closest(op);
+ int64_t elementByteWidth =
+ dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+ int64_t unusedOffset = 0;
+ SmallVector<int64_t, 5> strideVals;
+ if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+ return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+ Value numRecords = adaptor.getValidBytes();
+ if (!numRecords)
+ numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+ strideVals, elementByteWidth);
+
+ Value basePointer;
+ if (adaptor.getResetOffset())
+ basePointer =
+ descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
+ else
+ basePointer = descriptor.alignedPtr(rewriter, loc);
+
+ Value offset;
+ if (adaptor.getResetOffset())
+ offset = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ rewriter.getIndexAttr(0));
+ else
+ offset = descriptor.offset(rewriter, loc);
+
+ // No need to unpack() and pack() all the individual sizes and strides,
+ // so we'll just extract the arrays.
+ Value sizes = rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kSizePosInMemRefDescriptor);
+ Value strides = rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kStridePosInMemRefDescriptor);
+
+ Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
+ adaptor.getBoundsCheck(), chipset,
+ adaptor.getCacheSwizzleStride());
+ Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+ loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
+
+ Value result = MemRefDescriptor::undef(
+ rewriter, loc,
+ getTypeConverter()->convertType(op.getResult().getType()));
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
+ kStridePosInMemRefDescriptor);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
@@ -122,7 +282,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
Type i32 = rewriter.getI32Type();
- Type i16 = rewriter.getI16Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -199,60 +358,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Value ptr = memrefDescriptor.bufferPtr(
rewriter, loc, *this->getTypeConverter(), memrefType);
- // The stride value is always 0 for raw buffers. This also disables
- // swizling.
- Value stride = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(0));
- // Get the number of elements.
- Value numRecords;
- if (memrefType.hasStaticShape() &&
- !llvm::any_of(strides, ShapedType::isDynamic)) {
- int64_t size = memrefType.getRank() == 0 ? 1 : 0;
- ArrayRef<int64_t> shape = memrefType.getShape();
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
- size = std::max(shape[i] * strides[i], size);
- size = size * elementByteWidth;
- assert(size < std::numeric_limits<uint32_t>::max() &&
- "the memref buffer is too large");
- numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
- } else {
- Value maxIndex;
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
- Value size = memrefDescriptor.size(rewriter, loc, i);
- Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
- maxIndex =
- maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
- : maxThisDim;
- }
- numRecords = rewriter.create<LLVM::MulOp>(
- loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
- }
-
- // Flag word:
- // bits 0-11: dst sel, ignored by these intrinsics
- // bits 12-14: data format (ignored, must be nonzero, 7=float)
- // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
- // bit 19: In nested heap (0 here)
- // bit 20: Behavior on unmap (0 means "return 0 / ignore")
- // bits 21-22: Index stride for swizzles (N/A)
- // bit 23: Add thread ID (0)
- // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
- // bits 25-26: Reserved (0)
- // bit 27: Buffer is non-volatile (CDNA only)
- // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
- // none, 3 = either swizzles or testing against offset field) RDNA only
- // bits 30-31: Type (must be 0)
- uint32_t flags = (7 << 12) | (4 << 15);
- if (chipset.majorVersion >= 10) {
- flags |= (1 << 24);
- uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
- flags |= (oob << 28);
- }
- Value flagsConst = createI32Constant(rewriter, loc, flags);
- Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
- Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
- loc, rsrcType, ptr, stride, numRecords, flagsConst);
+ Value numRecords = getNumRecords(
+ rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
+ Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
+ adaptor.getBoundsCheck(), chipset);
args.push_back(resource);
// Indexing (voffset)
@@ -1062,11 +1171,32 @@ struct ConvertAMDGPUToROCDLPass
};
} // namespace
-void mlir::populateAMDGPUToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- Chipset chipset) {
+void mlir::populateAMDGPUMemorySpaceAttributeConversions(
+ TypeConverter &typeConverter) {
+ typeConverter.addTypeAttributeConversion(
+ [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
+ -> TypeConverter::AttributeConversionResult {
+ MLIRContext *ctx = as.getContext();
+ Type i64 = IntegerType::get(ctx, 64);
+ switch (as.getValue()) {
+ case amdgpu::AddressSpace::FatRawBuffer:
+ return IntegerAttr::get(i64, 7);
+ case amdgpu::AddressSpace::BufferRsrc:
+ return IntegerAttr::get(i64, 8);
+ case amdgpu::AddressSpace::FatStructuredBuffer:
+ return IntegerAttr::get(i64, 9);
+ }
+ return TypeConverter::AttributeConversionResult::abort();
+ });
+}
+
+void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ Chipset chipset) {
+ populateAMDGPUMemorySpaceAttributeConversions(converter);
patterns
- .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+ .add<FatRawBufferCastLowering,
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
RawBufferOpLowering<RawBufferAtomicFaddOp,
ROCDL::RawPtrBufferAtomicFaddOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..e944b6b5acb0c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// FatRawBufferCastOp
+//===----------------------------------------------------------------------===//
+
+/// Convert the type `source` to one with the same sizes and strides - and
+/// offset, unless `stripOffset` is true, in which case the offset is reset to
+/// 0, If the offset should be reset but the layout of `source` isn't either the
+/// identity layout or a strided layout, this function fails.
+static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
+ bool resetOffset) {
+ MLIRContext *ctx = source.getContext();
+ MemRefType::Builder mb(source);
+ mb.setMemorySpace(
+ amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
+ MemRefLayoutAttrInterface layout = source.getLayout();
+ if (resetOffset && !layout.isIdentity()) {
+ auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
+ if (!stridedLayout)
+ return failure();
+ mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+ }
+ return (MemRefType)(mb);
+}
+
+LogicalResult FatRawBufferCastOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ Adaptor adaptor(operands, attributes, properties, regions);
+ auto sourceType =
+ dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
+ if (!sourceType)
+ return failure();
+ FailureOr<MemRefType> resultType =
+ getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
+ if (failed(resultType))
+ return failure();
+ inferredReturnTypes = SmallVector<Type>{*resultType};
+ return success();
+}
+
+LogicalResult FatRawBufferCastOp::verify() {
+ FailureOr<MemRefType> expectedResultType =
+ getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
+ if (failed(expectedResultType))
+ return emitOpError("source type ")
+ << getSource().getType() << " can't have its offset reset";
+ if (getResult().getType() != *expectedResultType)
+ return emitOpError("expected result type to be ")
+ << *expectedResultType << " but got " << getResult().getType();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 062b63c076c3c..921975862a4bd 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -1,13 +1,107 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX908
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx90a | FileCheck %s --check-prefixes=CHECK,GFX9,GFX90A
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX942
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10,RDNA
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11,RDNA
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12,RDNA
+// Note: #gpu.address_space<global> is hardcoded to `1` here because the
+// test pass doesn't set up the GPU address space conversions.
+
+#gpu_global_addrspace = 1
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<8xi32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3]
+ // CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4]
+ // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+ // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+ // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
+ // CHECK: %[[ret0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+ // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+ // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+ // CHECK: %[[ret4:.*]] = llvm.insertvalue %[[sizes]], %[[ret3]][3]
+ // CHECK: %[[ret5:.*]] = llvm.insertvalue %[[strides]], %[[ret4]][4]
+ // CHECK: builtin.unrealized_conversion_cast %[[ret5]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
+func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
+ // CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0]
+ // CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]]
+ // CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32
+ // CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]]
+ // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2]
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+ // CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[basePtr:.*]] = llvm.getelementptr %[[memRefPtr]][%[[memRefOff]]]
+ // CHECK-DAG: %[[zeroOff:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
+ // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]]
+ // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
+ // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
+ %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
+func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[numRecords:.*]] = arith.constant -1 : i32
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+ %cu32_max = arith.constant 0xffffffff : i32
+ %ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_bounds_check
+func.func @fat_raw_buffer_cast_bounds_check(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(553807872 : i32)
+ // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf boundsCheck(false) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_cache_swizzle
+// CHECK-SAME: (%{{.*}}: memref<64x64xi32, 1>, %[[stride:.*]]: i14)
+func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global_addrspace>, %stride: i14) -> memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // GFX908: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX90A: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // RDNA: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX942: %[[asI16:.*]] = llvm.zext %[[stride]] : i14 to i16
+ // GFX942: %[[cacheSwizzleOn:.*]] = llvm.mlir.constant(16384 : i16) : i16
+ // GFX942: %[[stride:.*]] = llvm.or disjoint %[[asI16]], %[[cacheSwizzleOn]]
+ // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %{{.*}}, %{{.*}}
+ %ret = amdgpu.fat_raw_buffer_cast %buf cacheSwizzleStride(%stride) : memref<64x64xi32, #gpu_global_addrspace> to memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32
func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
+ // Extra constant for byte width
+ // CHECK: llvm.mlir.constant(4 : i32)
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32)
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -19,8 +113,8 @@ func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -37,7 +131,6 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
// CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
- // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
// CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
@@ -46,7 +139,9 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
// CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
// CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
// CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
- // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32
+ // CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32
+ // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
// CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
@@ -289,6 +384,8 @@ func.func @lds_barrier() {
// GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
// GFX90A: rocdl.s.waitcnt -7937
// GFX90A-NEXT: rocdl.s.barrier
+ // GFX942: rocdl.s.waitcnt -7937
+ // GFX942-NEXT: rocdl.s.barrier
// GFX10: rocdl.s.waitcnt -16129
// GFX10-NEXT: rocdl.s.barrier
// GFX11: llvm.inline_asm has_side_effects asm_dialect = att
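
As a cross-check on the magic flag constants in the CHECK lines above, they
follow from the flag-word layout documented in makeBufferRsrc (the arithmetic
below is derived from that code comment, not from hardware documentation):

  (7 << 12) | (4 << 15)          = 159744     (GFX9 default)
  159744 | (1 << 24) | (3 << 28) = 822243328  (RDNA, bounds check on)
  159744 | (1 << 24) | (2 << 28) = 553807872  (RDNA, boundsCheck(false))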
From 0e12ca086599130d3b90c10e76d190efa674acef Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 6 Feb 2025 04:06:30 +0000
Subject: [PATCH 2/8] Fixes to AMDGPUToROCDL PR
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 23 +++++++++++--------
2 files changed, 15 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6c42849fc71f1..c7284b4eabf05 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -71,7 +71,7 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
structured indexing that is primarily seen in graphics applications. This
is also incompatible with the simple indexing model supported by memref.
}];
- let assemblyFormat = [{ `<` $value `>` }];
+ let assemblyFormat = "`<` $value `>`";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 173b0b612ca0d..5a5fb9d48c983 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -210,12 +210,16 @@ struct FatRawBufferCastLowering
else
offset = descriptor.offset(rewriter, loc);
+ bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor);
- Value strides = rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor);
+ Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides = hasSizes
+ ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kStridePosInMemRefDescriptor)
+ : Value{};
Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
adaptor.getBoundsCheck(), chipset,
@@ -232,11 +236,12 @@ struct FatRawBufferCastLowering
loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
kOffsetPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
- kStridePosInMemRefDescriptor);
-
+ if (hasSizes) {
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, strides, kStridePosInMemRefDescriptor);
+ }
rewriter.replaceOp(op, result);
return success();
}
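
Context for the hasSizes guard above: a rank-0 memref descriptor has no
size/stride arrays, so descriptor positions 3 and 4 do not exist to extract
from or insert into. The two descriptor shapes, as they appear in the tests
(the rank-0 case is added in the next patch):

  rank 1: !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
  rank 0: !llvm.struct<(ptr<7>, ptr<7>, i64)>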
From 63039ff327fc9f4e0d3936ab0ad45f1f09027de3 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 6 Feb 2025 19:33:34 +0000
Subject: [PATCH 3/8] Fix assembly format, fix poison, add test for 0D
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +-
.../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 22 ++++++++++++++++++-
2 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5a5fb9d48c983..d66c44ea71a6d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -227,7 +227,7 @@ struct FatRawBufferCastLowering
Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
- Value result = MemRefDescriptor::undef(
+ Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
result = rewriter.create<LLVM::InsertValueOp>(
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 921975862a4bd..fbe88dcd57cee 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -23,7 +23,7 @@ func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> me
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
// CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
- // CHECK: %[[ret0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
// CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
// CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
@@ -34,6 +34,26 @@ func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> me
return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
}
+// CHECK-LABEL: func @fat_raw_buffer_cast_0d
+func.func @fat_raw_buffer_cast_0d(%buf: memref<i32, #gpu_global_addrspace>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+ // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<i32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64)>
+ // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+ // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+ // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+ // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+ // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+ // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+ // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
+ // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64)>
+ // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+ // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+ // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+ // CHECK: builtin.unrealized_conversion_cast %[[ret3]]
+ %ret = amdgpu.fat_raw_buffer_cast %buf : memref<i32, #gpu_global_addrspace> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
// CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
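
Background on the undef-to-poison switch above (general MLIR context, not
spelled out in this PR): descriptor construction in the LLVM lowering paths
has been migrating from llvm.mlir.undef to llvm.mlir.poison, which has
stricter, better-defined semantics; MemRefDescriptor::poison is the
corresponding helper for building a descriptor whose fields are then filled
in with llvm.insertvalue.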
From 9e042c8e425c6125e0984a8ae937f8d8ba2ad66c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 10 Feb 2025 10:34:23 -0600
Subject: [PATCH 4/8] Review comments
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 6 +++---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 +------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
3 files changed, 5 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c7284b4eabf05..56bef351e3bda 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -59,7 +59,7 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
- fat_raw_buffer is the memory space used when a memref is stored
as a "buffer fat pointer" - that is, a buffer resource (that is set up to
use raw byte-level indexing) along with its offset. The AMDGPU backend
- implements ptr addrspace(7) to represent these fat pointers so that
+ implements `ptr addrspace(7)` to represent these fat pointers so that
buffer resources (which allow advanced features like bounds checking or
cache swizzling) can be used like ordinary LLVM pointers or memrefs.
See also the fat_raw_buffer_cast operation
@@ -195,7 +195,7 @@ def AMDGPU_FatRawBufferCastOp :
If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
on architectures that support it. This swizzling, unlike the main swizzling
- mode (whose usage makes a buffer non-raw) does not affect index calculaton,
+ mode (whose usage makes a buffer non-raw) does not affect index calculation,
but does affect cache behavior. Mixing access between cache-swizzled raw
buffers and other forms of memory access, like ordinary pointer loads or
unswizzled buffer pointers can cause incorrect behavior and must be avoided.
@@ -203,7 +203,7 @@ def AMDGPU_FatRawBufferCastOp :
This operation preserves the sizes, strides, and offset of the input
memref - they'll be added in by `memref.load` later. However, if
`resetOffset` is set, that offset will be added to the base pointer.
- If the value of the memref's offset is not independent of the lane/thread ID,
+ If the value of the memref's offset is not uniform (independent of the lane/thread ID),
this will lead to substantially decreased performance due to the need for
a waterfall loop on the base address of the buffer resource.
}];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d66c44ea71a6d..cbfb45479f884 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -196,12 +196,7 @@ struct FatRawBufferCastLowering
numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
strideVals, elementByteWidth);
- Value basePointer;
- if (adaptor.getResetOffset())
- basePointer =
- descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
- else
- basePointer = descriptor.alignedPtr(rewriter, loc);
+Value basePointer = adaptor.getResetOffset() ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType) : descriptor.alignedPtr(rewriter, loc);
Value offset;
if (adaptor.getResetOffset())
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e944b6b5acb0c..d2bfb863244d9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -65,7 +65,7 @@ LogicalResult PackedStochRoundFp8Op::verify() {
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `stripOffset` is true, in which case the offset is reset to
-/// 0, If the offset should be reset but the layout of `source` isn't either the
+/// 0, if the offset should be reset but the layout of `source` isn't either the
/// identity layout or a strided layout, this function fails.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
bool resetOffset) {
From cd067b62d3d80475589caade77239f81ed9820a5 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 10 Feb 2025 17:50:19 +0000
Subject: [PATCH 5/8] Update tests, formatting
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 12 ++++-----
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++++-
mlir/test/Dialect/AMDGPU/invalid.mlir | 25 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/ops.mlir | 19 ++++++++++++++
4 files changed, 55 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56bef351e3bda..eb36cbecc5171 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -56,17 +56,17 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
AMDGPU-specific memory spaces that may not have exact analogues on other
GPU targets or backends.
- - fat_raw_buffer is the memory space used when a memref is stored
+ - `fat_raw_buffer` is the memory space used when a memref is stored
as a "buffer fat pointer" - that is, a buffer resource (that is set up to
use raw byte-level indexing) along with its offset. The AMDGPU backend
implements `ptr addrspace(7)` to represent these fat pointers so that
buffer resources (which allow advanced features like bounds checking or
cache swizzling) can be used like ordinary LLVM pointers or memrefs.
- See also the fat_raw_buffer_cast operation
- - buffer_rsrc is the memory space for ptr addrspace(8), representing a
+ See also the `fat_raw_buffer_cast` operation
+ - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a
buffer resource. It should not be used for memrefs, since it does not support
indexing
- - fat_structured_buffer represents ptr addrspace(9), a buffer resource
+ - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource
that carries both an index and offset field, which are used for complex
structured indexing that is primarily seen in graphics applications. This
is also incompatible with the simple indexing model supported by memref.
@@ -173,8 +173,8 @@ def AMDGPU_FatRawBufferCastOp :
Results<(outs AnyMemRef:$result)> {
let summary = "Create a raw buffer fat pointer that matches `memref`";
let description = [{
- Wraps the memory pointed to by `in` as a raw buffer fat pointer, or,
- in LLVM terms, a ptr addrspace(7), returning a memref that has the same
+ Wraps the memory pointed to by `source` as a raw buffer fat pointer, or,
+ in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same
sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
address space.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index cbfb45479f884..2aa153ead49be 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -196,7 +196,11 @@ struct FatRawBufferCastLowering
numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
strideVals, elementByteWidth);
-Value basePointer = adaptor.getResetOffset() ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType) : descriptor.alignedPtr(rewriter, loc);
+ Value basePointer =
+ adaptor.getResetOffset()
+ ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType)
+ : descriptor.alignedPtr(rewriter, loc);
Value offset;
if (adaptor.getResetOffset())
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5e1ab79962d2f..7cb16f5259070 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -125,3 +125,28 @@ func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32>
%0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
func.return %0 : vector<8xi32>
}
+
+// -----
+
+// Missing `resetOffset`
+func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// -----
+
+func.func @fat_raw_buffer_cast_wrong_as(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<buffer_rsrc>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<buffer_rsrc>>'}}
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<buffer_rsrc>>
+ return %ret : memref<8xi32, #amdgpu.address_space<buffer_rsrc>>
+}
+
+// -----
+
+func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ // expected-error at +1 {{'amdgpu.fat_raw_buffer_cast' op source type 'memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>' can't have its offset reset}}
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset : memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 9457a1b9e4498..567e6498330a3 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -25,6 +25,25 @@ func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M
func.return %ret : vector<4xf8E5M2FNUZ>
}
+// CHECK-LABEL: func @fat_raw_buffer_cast_easy
+// CHECK: amdgpu.fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+// CHECK: amdgpu.fat_raw_buffer_cast
+// CHECK-SAME: validBytes(%{{[^)]*}})
+// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
+// CHECK-SAME: boundsCheck(false)
+// CHECK-SAME: resetOffset
+func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
+ : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 {
// CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32
From 277752273523a6892a071beee2df54b3b728b0f9 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 10 Feb 2025 11:52:05 -0600
Subject: [PATCH 6/8] Constexpr in a file doesn't have to be static, I learned
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2aa153ead49be..24467478623eb 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -33,9 +33,9 @@ using namespace mlir;
using namespace mlir::amdgpu;
// Define commonly used chipsets versions for convenience.
-static constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-static constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-static constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+constexpr Chipset kGfx940 = Chipset(9, 4, 0);
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
From 40436314c9de4d2eb3b53a0ba40a0a0750aacaac Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 11 Feb 2025 20:03:27 +0000
Subject: [PATCH 7/8] Fix Windows test failure, move to ? operator, add strided
metadata patterns
---
.../mlir/Dialect/AMDGPU/Transforms/Passes.h | 3 +
.../mlir/Dialect/AMDGPU/Transforms/Passes.td | 16 +++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 15 ++--
.../Dialect/AMDGPU/Transforms/CMakeLists.txt | 2 +
.../Transforms/ResolveStridedMetadata.cpp | 68 +++++++++++++++++++
.../amdgpu-resolve-strided-metadata.mlir | 51 ++++++++++++++
6 files changed, 146 insertions(+), 9 deletions(-)
create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
create mode 100644 mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198..c3ae7930e8ec8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,12 +21,15 @@ class ConversionTarget;
namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
RewritePatternSet &patterns,
Chipset chipset);
+
+void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
} // namespace amdgpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 23f8b8f653b67..ef50367b67a21 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -31,4 +31,20 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
"Chipset that these operations will run on">];
}
+def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
+ let summary = "Resolve memref.extract_strided_metadata on AMDGPU ops";
+ let description = [{
+ This pass rewrites `memref.extract_strided_metadata` operations
+ that target the AMDGPU dialect's cast operations.
+
+ It's mainly meant for testing; please incorporate the patterns into your
+ own extract-strided-metadata passes (or run memref's expand-strided-metadata
+ again after this pass).
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "memref::MemRefDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 24467478623eb..2820d3b7ac2a1 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -111,9 +111,9 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
: maxThisDim;
}
- return rewriter.create<LLVM::MulOp>(
- loc, convertUnsignedToI32(rewriter, loc, maxIndex),
- createI32Constant(rewriter, loc, elementByteWidth));
+ Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
+ Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
+ return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -202,12 +202,9 @@ struct FatRawBufferCastLowering
memrefType)
: descriptor.alignedPtr(rewriter, loc);
- Value offset;
- if (adaptor.getResetOffset())
- offset = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(0));
- else
- offset = descriptor.offset(rewriter, loc);
+ Value offset = adaptor.getResetOffset() ? rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ rewriter.getIndexAttr(0))
+ : descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 5f934714d988a..3d4567bff1e32 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
+ ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
@@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRArithDialect
+ MLIRMemRefDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
new file mode 100644
index 0000000000000..3c29dab437b1b
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -0,0 +1,68 @@
+//===- ResolveStridedMetadata.cpp - AMDGPU expand_strided_metadata ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPURESOLVESTRIDEDMETADATAPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+namespace {
+struct AmdgpuResolveStridedMetadataPass : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase<AmdgpuResolveStridedMetadataPass> {
+ void runOnOperation() override;
+};
+
+struct ExtractStridedMetadataOnFatRawBufferCastFolder final : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, PatternRewriter &rewriter) const override {
+ auto castOp = metadataOp.getSource().getDefiningOp<FatRawBufferCastOp>();
+ if (!castOp)
+ return rewriter.notifyMatchFailure(metadataOp, "not a fat raw buffer cast");
+ Location loc = castOp.getLoc();
+ auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(loc, castOp.getSource());
+ SmallVector<Value> results;
+ if (metadataOp.getBaseBuffer().use_empty()) {
+ results.push_back(nullptr);
+ } else {
+ auto baseBufferType = cast<MemRefType>(metadataOp.getBaseBuffer().getType());
+ if (baseBufferType == castOp.getResult().getType()) {
+ results.push_back(castOp.getResult());
+ } else {
+ results.push_back(rewriter.create<memref::ReinterpretCastOp>(loc, baseBufferType, castOp.getResult(), /*offset=*/0, /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
+ }
+ }
+ if (castOp.getResetOffset())
+ results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ else
+ results.push_back(sourceMetadata.getOffset());
+ llvm::append_range(results, sourceMetadata.getSizes());
+ llvm::append_range(results, sourceMetadata.getStrides());
+ rewriter.replaceOp(metadataOp, results);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns) {
+ patterns.add<ExtractStridedMetadataOnFatRawBufferCastFolder>(patterns.getContext());
+}
+
+void AmdgpuResolveStridedMetadataPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateAmdgpuResolveStridedMetadataPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
new file mode 100644
index 0000000000000..831bb5f0f66ec
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt -amdgpu-resolve-strided-metadata -split-input-file %s | FileCheck %s
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_no_offset_reset
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %[[offset:.+]], %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]]
+// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_no_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index
+}
+
+// -----
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?]>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_offset_reset
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]]
+// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index
+}
+
+// -----
+
+!tSrc = memref<?x?xi32, strided<[?, ?], offset: ?>>
+!tDst = memref<?x?xi32, strided<[?, ?]>, #amdgpu.address_space<fat_raw_buffer>>
+!tRes = memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-LABEL: @resolve_metadata_no_base_ptr
+// CHECK-SAME: (%[[arg0:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]]
+// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]]
+// CHECK-NEXT: return %[[cast]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1
+func.func @resolve_metadata_no_base_ptr(%arg0: !tSrc) -> (!tDst, index, index, index, index, index) {
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst
+ %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index
+ func.return %cast, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tDst, index, index, index, index, index
+}
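
As the pass description above suggests, downstream users can pick up the pattern without scheduling the standalone pass. A minimal sketch follows; the wrapper function is hypothetical, while the populate and driver calls are the same ones the pass itself uses.

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical downstream hook: fold memref.extract_strided_metadata over
// amdgpu.fat_raw_buffer_cast as part of a larger cleanup.
mlir::LogicalResult resolveAmdgpuStridedMetadata(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(patterns);
  // Other extract-strided-metadata folding patterns can be added here too.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}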
>From 44d296f1fa5bb57b93616a2876ce1d957751878d Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 11 Feb 2025 20:07:41 +0000
Subject: [PATCH 8/8] Oh, right, the clang-format
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 +++--
.../Transforms/ResolveStridedMetadata.cpp | 29 +++++++++++++------
2 files changed, 24 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2820d3b7ac2a1..b7e2f5485b426 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -202,9 +202,10 @@ struct FatRawBufferCastLowering
memrefType)
: descriptor.alignedPtr(rewriter, loc);
- Value offset = adaptor.getResetOffset() ? rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(0))
- : descriptor.offset(rewriter, loc);
+ Value offset = adaptor.getResetOffset()
+ ? rewriter.create<LLVM::ConstantOp>(
+ loc, getIndexType(), rewriter.getIndexAttr(0))
+ : descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
index 3c29dab437b1b..4b3d94b4ce2ad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -21,27 +21,36 @@ using namespace mlir;
using namespace mlir::amdgpu;
namespace {
-struct AmdgpuResolveStridedMetadataPass : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase<AmdgpuResolveStridedMetadataPass> {
+struct AmdgpuResolveStridedMetadataPass
+ : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase<
+ AmdgpuResolveStridedMetadataPass> {
void runOnOperation() override;
};
-struct ExtractStridedMetadataOnFatRawBufferCastFolder final : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+struct ExtractStridedMetadataOnFatRawBufferCastFolder final
+ : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
+ PatternRewriter &rewriter) const override {
auto castOp = metadataOp.getSource().getDefiningOp<FatRawBufferCastOp>();
if (!castOp)
- return rewriter.notifyMatchFailure(metadataOp, "not a fat raw buffer cast");
+ return rewriter.notifyMatchFailure(metadataOp,
+ "not a fat raw buffer cast");
Location loc = castOp.getLoc();
- auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(loc, castOp.getSource());
+ auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, castOp.getSource());
SmallVector<Value> results;
if (metadataOp.getBaseBuffer().use_empty()) {
results.push_back(nullptr);
} else {
- auto baseBufferType = cast<MemRefType>(metadataOp.getBaseBuffer().getType());
+ auto baseBufferType =
+ cast<MemRefType>(metadataOp.getBaseBuffer().getType());
if (baseBufferType == castOp.getResult().getType()) {
results.push_back(castOp.getResult());
} else {
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(loc, baseBufferType, castOp.getResult(), /*offset=*/0, /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
+ results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+ loc, baseBufferType, castOp.getResult(), /*offset=*/0,
+ /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
}
}
if (castOp.getResetOffset())
@@ -56,8 +65,10 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final : public OpRewritePa
};
} // namespace
-void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns) {
- patterns.add<ExtractStridedMetadataOnFatRawBufferCastFolder>(patterns.getContext());
+void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ExtractStridedMetadataOnFatRawBufferCastFolder>(
+ patterns.getContext());
}
void AmdgpuResolveStridedMetadataPass::runOnOperation() {