[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Feb 10 08:34:32 PST 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/125594

>From dd4b1101d255deceb9036aaf3a5a3efd2d9a5745 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 3 Feb 2025 22:49:38 +0000
Subject: [PATCH 1/4] [mlir][AMDGPU] Plumb address space 7 through MLIR, add
 address_space attr.

This commit adds support to the AMDGPU dialect for casting memrefs
into fat raw buffer pointers.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided from some API) into a pointer that supports
ordinary pointer operations like load or store. This allows people to
take advantage of the additional semantics that buffer_load and
similar instructions provide without forcing the use of entirely
separate amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.

This commit also defines an #amdgpu.address_space<>
attribute so that AMDGPU-specific memory spaces can be represented.
Only #amdgpu.address_space<fat_raw_buffer> will work correctly with
the memref dialect, but the other possible address spaces are included
for completeness.
---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h  |  16 +-
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 105 +++++++
 .../mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h    |   2 +
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 256 +++++++++++++-----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  53 ++++
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 105 ++++++-
 6 files changed, 466 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index e7637a6013e68ad..bb4e7bc037a373c 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -16,18 +16,26 @@ namespace mlir {
 
 class LLVMTypeConverter;
 class RewritePatternSet;
+class TypeConverter;
 class Pass;
 
 #define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns,
                                              amdgpu::Chipset chipset);
 
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::FatRawBuffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::BufferRsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::FatStructuredBuffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter);
+
 std::unique_ptr<Pass> createConvertAMDGPUToROCDLPass();
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 69745addfd748ec..6c42849fc71f134 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -9,8 +9,11 @@
 #ifndef AMDGPU
 #define AMDGPU
 
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
 include "mlir/IR/OpBase.td"
 
 def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+    "AMDGPU-specific address spaces",
+    [
+      I32EnumAttrCase<"FatRawBuffer",        0, "fat_raw_buffer">,
+      I32EnumAttrCase<"BufferRsrc",          1, "buffer_rsrc">,
+      I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+    "address_space"> {
+  let description = [{
+    AMDGPU-specific memory spaces that may not have exact analogues on other
+    GPU targets or backends.
+
+    - fat_raw_buffer is the memory space used when a memref is stored
+    as a "buffer fat pointer" - that is, a buffer resource (that is set up to
+    use raw byte-level indexing) along with its offset. The AMDGPU backend
+    implements ptr addrspace(7) to represent these fat pointers so that
+    buffer resources (which allow advanced features like bounds checking or
+    cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+    See also the fat_raw_buffer_cast operation.
+    - buffer_rsrc is the memory space for ptr addrspace(8), representing a
+    buffer resource. It should not be used for memrefs, since it does not support
+    indexing.
+    - fat_structured_buffer represents ptr addrspace(9), a buffer resource
+    that carries both an index and offset field, which are used for complex
+    structured indexing that is primarily seen in graphics applications. This
+    is also incompatible with the simple indexing model supported by memref.
+  }];
+  let assemblyFormat = [{ `<` $value `>` }];
+}
+
 //===----------------------------------------------------------------------===//
 // AMDGPU Op definitions
 //===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
   let hasVerifier = 1;
 }
 
+def AMDGPU_FatRawBufferCastOp :
+    AMDGPU_Op<"fat_raw_buffer_cast",
+      [Pure,
+       DeclareOpInterfaceMethods<InferTypeOpInterface>,
+       ViewLikeOpInterface, AttrSizedOperandSegments]>,
+    Arguments<(ins AnyMemRef:$source,
+      Optional<I32>:$validBytes,
+      Optional<I<14>>:$cacheSwizzleStride,
+      DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+      UnitProp:$resetOffset)>,
+    Results<(outs AnyMemRef:$result)> {
+  let summary = "Create a raw buffer fat pointer that matches `memref`";
+  let description = [{
+    Wraps the memory pointed to by `source` as a raw buffer fat pointer, or,
+    in LLVM terms, a ptr addrspace(7), returning a memref that has the same
+    sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+    address space.
+
+    This memref can be used with standard memref operations like `memref.load`,
+    `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant
+    buffer intrinsics. (`vector.masked_load/store` will work once there's backend
+    support for lowering them, and then this document will be updated.)
+
+    If `validBytes` is given, it is the number of bytes that will be valid as
+    an offset into `result`. If it is not provided, this will be inferred from
+    the size of the memref (not counting its offset) during lowering. This is
+    max_d (sizes[d] * strides[d]) * sizeof(element type).
+
+    The flags of the buffer descriptor will be set up to enable raw usage -
+    for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+    property determines if bounds checking is enabled or not (on architectures
+    where this can be controlled - that is, on RDNA chips).
+
+    If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+    on architectures that support it. This swizzling, unlike the main swizzling
+    mode (whose usage makes a buffer non-raw) does not affect index calculation,
+    but does affect cache behavior. Mixing access between cache-swizzled raw
+    buffers and other forms of memory access, like ordinary pointer loads or
+    unswizzled buffer pointers can cause incorrect behavior and must be avoided.
+
+    This operation preserves the sizes, strides, and offset of the input
+    memref - they'll be added in by `memref.load` later. However, if
+    `resetOffset` is set, that offset is added to the base pointer and reset to 0.
+    If the value of the memref's offset is not independent of the lane/thread ID,
+    this will lead to substantially decreased performance due to the need for
+    a waterfall loop on the base address of the buffer resource.
+  }];
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let assemblyFormat = [{
+    $source oilist (`validBytes` `(` $validBytes `)`
+      | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+      | `boundsCheck` `(` $boundsCheck `)`
+      | `resetOffset` $resetOffset )
+    attr-dict `:` type($source) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
 /// Raw buffer load
 def AMDGPU_RawBufferLoadOp :
     AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 0a2e6bb5e9fe494..3de57c923178ad9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -18,7 +18,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9fb51f0bc1f1ea7..173b0b612ca0df5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -19,6 +19,8 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 
+#include "../LLVMCommon/MemRefDescriptor.h"
+
 #include "llvm/ADT/STLExtras.h"
 #include <optional>
 
@@ -30,6 +32,11 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+// Define commonly used chipset versions for convenience.
+static constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+static constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+static constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                   Location loc, Value val) {
@@ -76,11 +83,164 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
   return index ? index : createI32Constant(rewriter, loc, 0);
 }
 
+/// Compute the `num_records` field (a byte count) of a buffer descriptor for
+/// `memrefType`: max_d(sizes[d] * strides[d]) * elementByteWidth, or a single
+/// element for rank-0 memrefs. The memref's offset is not included.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+                           MemRefType memrefType,
+                           MemRefDescriptor &memrefDescriptor,
+                           ArrayRef<int64_t> strides,
+                           uint32_t elementByteWidth) {
+  if (memrefType.hasStaticShape() &&
+      !llvm::any_of(strides, ShapedType::isDynamic)) {
+    int64_t size = memrefType.getRank() == 0 ? 1 : 0; // rank 0: one element
+    ArrayRef<int64_t> shape = memrefType.getShape();
+    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+      size = std::max(shape[i] * strides[i], size);
+    size = size * elementByteWidth;
+    assert(size < std::numeric_limits<uint32_t>::max() &&
+           "the memref buffer is too large");
+    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+  }
+  Value maxIndex; // Dynamic case: compute the extent from the descriptor.
+  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+    Value size = memrefDescriptor.size(rewriter, loc, i);
+    Value stride = memrefDescriptor.stride(rewriter, loc, i);
+    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+    maxIndex = maxIndex
+                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                   : maxThisDim;
+  }
+  return rewriter.create<LLVM::MulOp>(
+      loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+      createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+                            Value basePointer, Value numRecords,
+                            bool boundsCheck, amdgpu::Chipset chipset,
+                            Value cacheSwizzleStride = nullptr) {
+  // The stride is 0 unless `cacheSwizzleStride` is given on a gfx9 chip at or
+  // after gfx940 (MI-300), where cache swizzling is enabled by setting bit 14
+  // of the stride field and putting the cache stride in the low bits.
+  Type i16 = rewriter.getI16Type();
+  Value stride;
+  if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+    Value cacheStrideZext =
+        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+                                         /*isDisjoint=*/true);
+  } else {
+    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+                                               rewriter.getI16IntegerAttr(0));
+  }
+  // Build the descriptor's flag word; the bit layout is described below.
+  // Flag word:
+  // bits 0-11: dst sel, ignored by these intrinsics
+  // bits 12-14: numeric format (ignored, must be nonzero, 7=float)
+  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+  // bit 19: In nested heap (0 here)
+  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+  // bits 21-22: Index stride for swizzles (N/A)
+  // bit 23: Add thread ID (0)
+  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+  // bits 25-26: Reserved (0)
+  // bit 27: Buffer is non-volatile (CDNA only)
+  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+  //  none, 3 = either swizzles or testing against offset field) RDNA only
+  // bits 30-31: Type (must be 0)
+  uint32_t flags = (7 << 12) | (4 << 15);
+  if (chipset.majorVersion >= 10) {
+    flags |= (1 << 24); // bit 24 is reserved to 1 on RDNA
+    uint32_t oob = boundsCheck ? 3 : 2; // 3 = offset-checked, 2 = none
+    flags |= (oob << 28);
+  }
+  Value flagsConst = createI32Constant(rewriter, loc, flags);
+  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+  return resource;
+}
+
 namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset; // Selects descriptor flags and cache-swizzle support.
+
+  LogicalResult
+  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value memRef = adaptor.getSource();
+    Value unconvertedMemref = op.getSource();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+    MemRefDescriptor descriptor(memRef);
+
+    DataLayout dataLayout = DataLayout::closest(op);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+    int64_t unusedOffset = 0; // Only the static strides are needed below.
+    SmallVector<int64_t, 5> strideVals;
+    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+      return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+    Value numRecords = adaptor.getValidBytes();
+    if (!numRecords) // NOTE(review): inferred size excludes a kept offset.
+      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+                                 strideVals, elementByteWidth);
+
+    Value basePointer;
+    if (adaptor.getResetOffset()) // Fold the offset into the base pointer.
+      basePointer =
+          descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
+    else
+      basePointer = descriptor.alignedPtr(rewriter, loc);
+
+    Value offset;
+    if (adaptor.getResetOffset()) // The result's offset then becomes 0.
+      offset = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                 rewriter.getIndexAttr(0));
+    else
+      offset = descriptor.offset(rewriter, loc);
+
+    // Sizes and strides are unchanged, so copy the whole arrays instead of
+    // unpack()ing and re-pack()ing every element.
+    Value sizes = rewriter.create<LLVM::ExtractValueOp>(
+        loc, descriptor, kSizePosInMemRefDescriptor);
+    Value strides = rewriter.create<LLVM::ExtractValueOp>(
+        loc, descriptor, kStridePosInMemRefDescriptor);
+
+    Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
+                                adaptor.getBoundsCheck(), chipset,
+                                adaptor.getCacheSwizzleStride());
+    Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
+
+    Value result = MemRefDescriptor::undef( // Descriptor over ptr<7>.
+        rewriter, loc,
+        getTypeConverter()->convertType(op.getResult().getType()));
+    result = rewriter.create<LLVM::InsertValueOp>(
+        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(
+        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
+                                                  kOffsetPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+                                                  kSizePosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
+                                                  kStridePosInMemRefDescriptor);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
 
 /// Define lowering patterns for raw buffer ops
 template <typename GpuOp, typename Intrinsic>
@@ -122,7 +282,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type i16 = rewriter.getI16Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -199,60 +358,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Value ptr = memrefDescriptor.bufferPtr(
         rewriter, loc, *this->getTypeConverter(), memrefType);
-    // The stride value is always 0 for raw buffers. This also disables
-    // swizling.
-    Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, i16, rewriter.getI16IntegerAttr(0));
-    // Get the number of elements.
-    Value numRecords;
-    if (memrefType.hasStaticShape() &&
-        !llvm::any_of(strides, ShapedType::isDynamic)) {
-      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
-      ArrayRef<int64_t> shape = memrefType.getShape();
-      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
-        size = std::max(shape[i] * strides[i], size);
-      size = size * elementByteWidth;
-      assert(size < std::numeric_limits<uint32_t>::max() &&
-             "the memref buffer is too large");
-      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
-    } else {
-      Value maxIndex;
-      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = memrefDescriptor.size(rewriter, loc, i);
-        Value stride = memrefDescriptor.stride(rewriter, loc, i);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex =
-            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
-                     : maxThisDim;
-      }
-      numRecords = rewriter.create<LLVM::MulOp>(
-          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
-    }
-
-    // Flag word:
-    // bits 0-11: dst sel, ignored by these intrinsics
-    // bits 12-14: data format (ignored, must be nonzero, 7=float)
-    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
-    // bit 19: In nested heap (0 here)
-    // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
-    // bits 21-22: Index stride for swizzles (N/A)
-    // bit 23: Add thread ID (0)
-    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
-    // bits 25-26: Reserved (0)
-    // bit 27: Buffer is non-volatile (CDNA only)
-    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
-    //  none, 3 = either swizzles or testing against offset field) RDNA only
-    // bits 30-31: Type (must be 0)
-    uint32_t flags = (7 << 12) | (4 << 15);
-    if (chipset.majorVersion >= 10) {
-      flags |= (1 << 24);
-      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
-      flags |= (oob << 28);
-    }
-    Value flagsConst = createI32Constant(rewriter, loc, flags);
-    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
-    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
-        loc, rsrcType, ptr, stride, numRecords, flagsConst);
+    Value numRecords = getNumRecords(
+        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
+    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
+                                    adaptor.getBoundsCheck(), chipset);
     args.push_back(resource);
 
     // Indexing (voffset)
@@ -1062,11 +1171,32 @@ struct ConvertAMDGPUToROCDLPass
 };
 } // namespace
 
-void mlir::populateAMDGPUToROCDLConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    Chipset chipset) {
+void mlir::populateAMDGPUMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter) {
+  typeConverter.addTypeAttributeConversion( // memref memory-space attrs
+      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
+          -> TypeConverter::AttributeConversionResult {
+        MLIRContext *ctx = as.getContext();
+        Type i64 = IntegerType::get(ctx, 64); // Address spaces as i64 attrs.
+        switch (as.getValue()) {
+        case amdgpu::AddressSpace::FatRawBuffer:
+          return IntegerAttr::get(i64, 7);
+        case amdgpu::AddressSpace::BufferRsrc:
+          return IntegerAttr::get(i64, 8);
+        case amdgpu::AddressSpace::FatStructuredBuffer:
+          return IntegerAttr::get(i64, 9);
+        } // Every enum case returns above.
+        return TypeConverter::AttributeConversionResult::abort();
+      });
+}
+
+void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                                   RewritePatternSet &patterns,
+                                                   Chipset chipset) {
+  populateAMDGPUMemorySpaceAttributeConversions(converter);
   patterns
-      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+      .add<FatRawBufferCastLowering,
+           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
            RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
            RawBufferOpLowering<RawBufferAtomicFaddOp,
                                ROCDL::RawPtrBufferAtomicFaddOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0bac..e944b6b5acb0c61 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FatRawBufferCastOp
+//===----------------------------------------------------------------------===//
+
+/// Return `source` moved into the fat_raw_buffer address space, keeping its
+/// sizes, strides, and offset. If `resetOffset` is set, the offset becomes 0;
+/// that case fails when `source` has neither the identity layout nor a
+/// strided layout.
+static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
+                                                     bool resetOffset) {
+  MLIRContext *ctx = source.getContext();
+  MemRefType::Builder mb(source);
+  mb.setMemorySpace(
+      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
+  MemRefLayoutAttrInterface layout = source.getLayout();
+  if (resetOffset && !layout.isIdentity()) {
+    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayout)
+      return failure();
+    mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+  }
+  return static_cast<MemRefType>(mb);
+}
+
+LogicalResult FatRawBufferCastOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  auto sourceType =
+      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
+  if (!sourceType) // Source is not a memref.
+    return failure();
+  FailureOr<MemRefType> resultType =
+      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
+  if (failed(resultType)) // Offset can't be reset for this layout.
+    return failure();
+  inferredReturnTypes = SmallVector<Type>{*resultType}; // Single result.
+  return success();
+}
+
+LogicalResult FatRawBufferCastOp::verify() {
+  FailureOr<MemRefType> expectedResultType = // Type the op should produce.
+      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
+  if (failed(expectedResultType)) // resetOffset on a non-strided layout.
+    return emitOpError("source type ")
+           << getSource().getType() << " can't have its offset reset";
+  if (getResult().getType() != *expectedResultType) // Must match exactly.
+    return emitOpError("expected result type to be ")
+           << *expectedResultType << " but got " << getResult().getType();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 062b63c076c3cbe..921975862a4bd32 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -1,13 +1,107 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX908
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx90a | FileCheck %s --check-prefixes=CHECK,GFX9,GFX90A
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX942
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10,RDNA
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11,RDNA
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12,RDNA
 
+// Note: #gpu.address_space<global> is hardcoded to `1` here because the
+// test pass doesn't set up the GPU address space conversions.
+
+#gpu_global_addrspace = 1
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<8xi32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3]
+  // CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4]
+  // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+  // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
+  // CHECK: %[[ret0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+  // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+  // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+  // CHECK: %[[ret4:.*]] = llvm.insertvalue %[[sizes]], %[[ret3]][3]
+  // CHECK: %[[ret5:.*]] = llvm.insertvalue %[[strides]], %[[ret4]][4]
+  // CHECK: builtin.unrealized_conversion_cast %[[ret5]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
+func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
+  // CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0]
+  // CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]]
+  // CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32
+  // CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]]
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2]
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[basePtr:.*]] = llvm.getelementptr %[[memRefPtr]][%[[memRefOff]]]
+  // CHECK-DAG: %[[zeroOff:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
+  // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]]
+  // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
+  // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
+  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
+func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[numRecords:.*]] = arith.constant -1 : i32
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  %cu32_max = arith.constant 0xffffffff : i32
+  %ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_bounds_check
+func.func @fat_raw_buffer_cast_bounds_check(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(553807872 : i32)
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf boundsCheck(false) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_cache_swizzle
+// CHECK-SAME: (%{{.*}}: memref<64x64xi32, 1>, %[[stride:.*]]: i14)
+func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global_addrspace>, %stride: i14) -> memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // GFX908: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX90A: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // RDNA: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX942: %[[asI16:.*]] = llvm.zext %[[stride]] : i14 to i16
+  // GFX942: %[[cacheSwizzleOn:.*]] = llvm.mlir.constant(16384 : i16) : i16
+  // GFX942: %[[stride:.*]] = llvm.or disjoint %[[asI16]], %[[cacheSwizzleOn]]
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %{{.*}}, %{{.*}}
+  %ret = amdgpu.fat_raw_buffer_cast %buf cacheSwizzleStride(%stride) : memref<64x64xi32, #gpu_global_addrspace> to memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32
 func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
-  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
+  // Extra constant for byte width
+  // CHECK: llvm.mlir.constant(4 : i32)
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
   // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -19,8 +113,8 @@ func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
 
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
 func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
-  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
   // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@@ -37,7 +131,6 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
     // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
-    // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
     // CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
@@ -46,7 +139,9 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
     // CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
     // CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
     // CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
-    // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32
+    // CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32
+    // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32
+    // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
     // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
     // CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
@@ -289,6 +384,8 @@ func.func @lds_barrier() {
   // GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
   // GFX90A: rocdl.s.waitcnt -7937
   // GFX90A-NEXT: rocdl.s.barrier
+  // GFX942: rocdl.s.waitcnt -7937
+  // GFX942-NEXT: rocdl.s.barrier
   // GFX10:  rocdl.s.waitcnt -16129
   // GFX10-NEXT: rocdl.s.barrier
   // GFX11:  llvm.inline_asm has_side_effects asm_dialect = att

>From 0e12ca086599130d3b90c10e76d190efa674acef Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 6 Feb 2025 04:06:30 +0000
Subject: [PATCH 2/4] Fixes to AMDGPUToROCDL PR

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 23 +++++++++++--------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6c42849fc71f134..c7284b4eabf055b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -71,7 +71,7 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
     structured indexing that is primarily seen in graphics applications. This
     is also incompatible with the simple indexing model supported by memref.
   }];
-  let assemblyFormat = [{ `<` $value `>` }];
+  let assemblyFormat = "`<` $value `>`";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 173b0b612ca0df5..5a5fb9d48c983ae 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -210,12 +210,16 @@ struct FatRawBufferCastLowering
     else
       offset = descriptor.offset(rewriter, loc);
 
+    bool hasSizes = memrefType.getRank() > 0;
     // No need to unpack() and pack() all the individual sizes and strides,
     // so we'll just extract the arrays.
-    Value sizes = rewriter.create<LLVM::ExtractValueOp>(
-        loc, descriptor, kSizePosInMemRefDescriptor);
-    Value strides = rewriter.create<LLVM::ExtractValueOp>(
-        loc, descriptor, kStridePosInMemRefDescriptor);
+    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
+                                 loc, descriptor, kSizePosInMemRefDescriptor)
+                           : Value{};
+    Value strides = hasSizes
+                        ? rewriter.create<LLVM::ExtractValueOp>(
+                              loc, descriptor, kStridePosInMemRefDescriptor)
+                        : Value{};
 
     Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
                                 adaptor.getBoundsCheck(), chipset,
@@ -232,11 +236,12 @@ struct FatRawBufferCastLowering
         loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
     result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                   kOffsetPosInMemRefDescriptor);
-    result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
-                                                  kSizePosInMemRefDescriptor);
-    result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
-                                                  kStridePosInMemRefDescriptor);
-
+    if (hasSizes) {
+      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+                                                    kSizePosInMemRefDescriptor);
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, result, strides, kStridePosInMemRefDescriptor);
+    }
     rewriter.replaceOp(op, result);
     return success();
   }

>From 63039ff327fc9f4e0d3936ab0ad45f1f09027de3 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 6 Feb 2025 19:33:34 +0000
Subject: [PATCH 3/4] Fix assembly format, fix poison, add test for 0D

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  2 +-
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 22 ++++++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5a5fb9d48c983ae..d66c44ea71a6d72 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -227,7 +227,7 @@ struct FatRawBufferCastLowering
     Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
         loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
 
-    Value result = MemRefDescriptor::undef(
+    Value result = MemRefDescriptor::poison(
         rewriter, loc,
         getTypeConverter()->convertType(op.getResult().getType()));
     result = rewriter.create<LLVM::InsertValueOp>(
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 921975862a4bd32..fbe88dcd57cee91 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -23,7 +23,7 @@ func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> me
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
   // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
-  // CHECK: %[[ret0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
   // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
   // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
@@ -34,6 +34,26 @@ func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> me
   return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
 }
 
+// CHECK-LABEL: func @fat_raw_buffer_cast_0d
+func.func @fat_raw_buffer_cast_0d(%buf: memref<i32, #gpu_global_addrspace>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<i32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64)>
+  // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+  // CHECK: %[[fatBuf:.*]] = llvm.addrspacecast %[[rsrc]] : !llvm.ptr<8> to !llvm.ptr<7>
+  // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64)>
+  // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+  // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+  // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+  // CHECK: builtin.unrealized_conversion_cast %[[ret3]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<i32, #gpu_global_addrspace> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
 // CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
 func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
   // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]

>From 9e042c8e425c6125e0984a8ae937f8d8ba2ad66c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 10 Feb 2025 10:34:23 -0600
Subject: [PATCH 4/4] Review comments

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Co-authored-by: Prashant Kumar <pk5561 at gmail.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td       | 6 +++---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 +------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        | 2 +-
 3 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index c7284b4eabf055b..56bef351e3bda1e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -59,7 +59,7 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
     - fat_raw_buffer is the memory space used when a memref is stored as
     as a "buffer fat pointer" - that is, a buffer resource (that is set up to
     use raw byte-level indexing) along with its offset. The AMDGPU backend
-    implements ptr addrspace(7) to represent these fat pointers so that
+    implements `ptr addrspace(7)` to represent these fat pointers so that
     buffer resources (which allow advanced features like bounds checking or
     cache swizzling) can be used like ordinary LLVM pointers or memrefs.
     See also the fat_raw_buffer_cast operation
@@ -195,7 +195,7 @@ def AMDGPU_FatRawBufferCastOp :
 
     If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
     on architectures that support it. This swizzling, unlike the main swizzling
-    mode (whose usage makes a buffer non-raw) does not affect index calculaton,
+    mode (whose usage makes a buffer non-raw) does not affect index calculation,
     but does affect cache behavior. Mixing access between cache-swizzled raw
     buffers and other forms of memory access, like ordinary pointer loads or
     unswizzled buffer pointers can cause incorrect behavior and must be avoided.
@@ -203,7 +203,7 @@ def AMDGPU_FatRawBufferCastOp :
     This operation preserves the sizes, strides, and offset of the input
     memref - they'll be added in by `memref.load` later. However, if
     `resetOffset` is set, that offset will be added to the base pointer.
-    If the value of the memref's offset is not independent of the lane/thread ID,
+    If the value of the memref's offset is not uniform (independent of the lane/thread ID),
     this will lead to substantially decreased performance due to the need for
     a waterfall loop on the base address of the buffer resource.
   }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d66c44ea71a6d72..cbfb45479f88456 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -196,12 +196,7 @@ struct FatRawBufferCastLowering
       numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                  strideVals, elementByteWidth);
 
-    Value basePointer;
-    if (adaptor.getResetOffset())
-      basePointer =
-          descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
-    else
-      basePointer = descriptor.alignedPtr(rewriter, loc);
+Value basePointer = adaptor.getResetOffset()  ?  descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType) : descriptor.alignedPtr(rewriter, loc);
 
     Value offset;
     if (adaptor.getResetOffset())
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e944b6b5acb0c61..d2bfb863244d989 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -65,7 +65,7 @@ LogicalResult PackedStochRoundFp8Op::verify() {
 
 /// Convert the type `source` to one with the same sizes and strides - and
 /// offset, unless `stripOffset` is true, in which case the offset is reset to
-/// 0, If the offset should be reset but the layout of `source` isn't either the
+/// 0. If the offset should be reset but the layout of `source` isn't either the
 /// identity layout or a strided layout, this function fails.
 static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                      bool resetOffset) {



More information about the Mlir-commits mailing list