[Mlir-commits] [mlir] [mlir][AMDGPU] Fix raw buffer ptr ops lowering (PR #122293)

Fabian Mora llvmlistbot at llvm.org
Thu Jan 9 07:12:15 PST 2025

https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/122293

This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:
 - Incorrect handling of the memref offset, which caused errors when using subviews. Furthermore, it cannot be assumed that the memref offset can be placed in an SGPR, as it may be a thread-dependent value.
 - Using `MaximumOp` (a floating-point-specific op) to calculate the number of records.
 - An incorrect number of records in the static-shape case.
 - Incorrect lowering when the index bitwidth is 64 (i64).

Additionally, this patch switches to using MLIR's data layout to get the type size.

>From cd265826503ab59f6e4e6cdbcf6e15a6d3bddc16 Mon Sep 17 00:00:00 2001
From: fabian <6982088+fabianmcg at users.noreply.github.com>
Date: Thu, 9 Jan 2025 06:52:40 -0800
Subject: [PATCH] [mlir][AMDGPU] Fix raw buffer ptr ops lowering

This patch fixes several bugs in the lowering of AMDGPU raw buffer operations.
These bugs include:
 - Incorrect handling of the memref offset, which caused errors when using subviews. Furthermore, it cannot be assumed that the memref offset can be placed in an SGPR, as it may be a thread-dependent value.
 - Using `MaximumOp` (a floating-point-specific op) to calculate the number of records.
 - An incorrect number of records in the static-shape case.
 - Incorrect lowering when the index bitwidth is 64 (i64).

Additionally, this patch switches to using MLIR's data layout to get the type size.
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 127 +++++++++++-------
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        |  14 +-
 2 files changed, 83 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4100b086fad8ba..49ac4723c2fb94 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -30,10 +30,23 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value val) {
+  IntegerType i32 = rewriter.getI32Type();
+  // Force check that `val` is of int type.
+  auto valTy = cast<IntegerType>(val.getType());
+  if (i32 == valTy)
+    return val;
+  return valTy.getWidth() > 32
+             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
-  Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  Type i32 = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,28 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+                               Location loc, MemRefDescriptor &memRefDescriptor,
+                               ValueRange indices, ArrayRef<int64_t> strides) {
+  IntegerType i32 = rewriter.getI32Type();
+  Value index;
+  for (int i = 0, e = indices.size(); i < e; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? convertUnsignedToI32(rewriter, loc,
+                                     memRefDescriptor.stride(rewriter, loc, i))
+              : rewriter.create<LLVM::ConstantOp>(loc, i32, strides[i]);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+  return index ? index : createI32Constant(rewriter, loc, 0);
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
@@ -88,17 +123,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
     Type i32 = rewriter.getI32Type();
-    Type llvmI32 = this->typeConverter->convertType(i32);
-    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+    Type i16 = rewriter.getI16Type();
-    auto toI32 = [&](Value val) -> Value {
-      if (val.getType() == llvmI32)
-        return val;
-      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
-    };
-    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+    // Get the type size in bytes.
+    DataLayout dataLayout = DataLayout::closest(gpuOp);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
     // If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +144,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t vecLen = dataVector.getNumElements();
-      uint32_t elemBits = dataVector.getElementTypeBitWidth();
+      uint32_t elemBits =
+          dataLayout.getTypeSizeInBits(dataVector.getElementType());
       uint32_t totalBits = elemBits * vecLen;
       bool usePackedFp16 =
           isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +198,37 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     MemRefDescriptor memrefDescriptor(memref);
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
     Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() && !llvm::any_of(strides, [](int64_t v) {
+          return ShapedType::isDynamic(v);
+        })) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        Value maxThisDim = rewriter.create<LLVM::MulOp>(
+            loc, memrefDescriptor.size(rewriter, loc, i),
+            memrefDescriptor.stride(rewriter, loc, i));
+        maxIndex =
+            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                     : maxThisDim;
-      numRecords = maxIndex;
+      numRecords = rewriter.create<LLVM::MulOp>(
+          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
     // Flag word:
@@ -218,40 +258,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     // Indexing (voffset)
-    Value voffset = createI32Constant(rewriter, loc, 0);
-    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
-      size_t i = pair.index();
-      Value index = pair.value();
-      Value strideOp;
-      if (ShapedType::isDynamic(strides[i])) {
-        strideOp = rewriter.create<LLVM::MulOp>(
-            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
-            byteWidthConst);
-      } else {
-        strideOp =
-            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
-      }
-      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
-      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
-    }
-    if (adaptor.getIndexOffset()) {
-      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
-      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+                                      adaptor.getIndices(), strides);
+    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+        indexOffset && *indexOffset > 0) {
+      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
       voffset =
           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                   : extraOffsetConst;
+    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+    // SGPR offset.
     Value sgprOffset = adaptor.getSgprOffset();
     if (!sgprOffset)
       sgprOffset = createI32Constant(rewriter, loc, 0);
-    if (ShapedType::isDynamic(offset))
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
-    else if (offset > 0)
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
     // bit 0: GLC = 0 (atomics drop value, less coherency)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 4c7515dc810516..92ecbff3e691dc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -32,14 +32,16 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
 func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
-  // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
-  // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+  // CHECK: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
   // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
   // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
-  // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
-  // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
+  // CHECK: %[[tmp:.*]] = llvm.mul %[[size]], %[[stride]] : i64
+  // CHECK: %[[num_elem:.*]] = llvm.trunc %[[tmp]] : i64 to i32
+  // CHECK: %[[numRecords:.*]] = llvm.mul %[[num_elem]], %[[elem_size]] : i32
   // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>

More information about the Mlir-commits mailing list