[Mlir-commits] [mlir] [mlir][AMDGPU] Fix raw buffer ptr ops lowering (PR #122293)
llvmlistbot at llvm.org
Thu Jan 9 07:20:32 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Fabian Mora (fabianmcg)
Changes:
This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:
- Incorrect handling of the memref offset, which caused errors when lowering subviews. Furthermore, the offset cannot be assumed to fit in an SGPR, since it may be a thread-dependent value.
- Use of `LLVM::MaximumOp` (a float-specific op) to compute the number of records.
- An incorrect number of records in the static-shape case.
- Broken lowering when the index bitwidth is i64.
Furthermore, this patch switches to MLIR's data layout to compute type sizes. A sketch of the subview case appears below.
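As a rough illustration of the first bug (a hypothetical reproducer, not taken from the patch): a subview introduces a dynamic offset that may be computed per thread, so it cannot be folded into the scalar (SGPR) offset; the fixed lowering folds it into the base pointer instead.

```mlir
// Hypothetical example: %off may be thread-dependent, so the subview's
// dynamic offset must not be treated as wavefront-uniform.
func.func @load_from_subview(%buf: memref<128xi32>, %off: index, %idx: i32) -> i32 {
  %view = memref.subview %buf[%off] [64] [1]
      : memref<128xi32> to memref<64xi32, strided<[1], offset: ?>>
  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %view[%idx]
      : memref<64xi32, strided<[1], offset: ?>>, i32 -> i32
  return %0 : i32
}
```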
---
Full diff: https://github.com/llvm/llvm-project/pull/122293.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+75-52)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir (+8-6)
``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4100b086fad8ba..49ac4723c2fb94 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -30,10 +30,23 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value val) {
+  IntegerType i32 = rewriter.getI32Type();
+  // Force check that `val` is of int type.
+  auto valTy = cast<IntegerType>(val.getType());
+  if (i32 == valTy)
+    return val;
+  return valTy.getWidth() > 32
+             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+}
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
-  Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  Type i32 = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
 }
 
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,28 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
 }
 
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+                               Location loc, MemRefDescriptor &memRefDescriptor,
+                               ValueRange indices, ArrayRef<int64_t> strides) {
+  IntegerType i32 = rewriter.getI32Type();
+  Value index;
+  for (int i = 0, e = indices.size(); i < e; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? convertUnsignedToI32(rewriter, loc,
+                                     memRefDescriptor.stride(rewriter, loc, i))
+              : rewriter.create<LLVM::ConstantOp>(loc, i32, strides[i]);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+  return index ? index : createI32Constant(rewriter, loc, 0);
+}
+
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
@@ -88,17 +123,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type llvmI32 = this->typeConverter->convertType(i32);
-    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+    Type i16 = rewriter.getI16Type();
 
-    auto toI32 = [&](Value val) -> Value {
-      if (val.getType() == llvmI32)
-        return val;
-
-      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
-    };
-
-    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+    // Get the type size in bytes.
+    DataLayout dataLayout = DataLayout::closest(gpuOp);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
 
     // If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +144,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     }
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t vecLen = dataVector.getNumElements();
-      uint32_t elemBits = dataVector.getElementTypeBitWidth();
+      uint32_t elemBits =
+          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
       bool usePackedFp16 =
           isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +198,37 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     MemRefDescriptor memrefDescriptor(memref);
 
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
     Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() && !llvm::any_of(strides, [](int64_t v) {
+          return ShapedType::isDynamic(v);
+        })) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        Value maxThisDim = rewriter.create<LLVM::MulOp>(
+            loc, memrefDescriptor.size(rewriter, loc, i),
+            memrefDescriptor.stride(rewriter, loc, i));
+        maxIndex =
+            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                     : maxThisDim;
       }
-      numRecords = maxIndex;
+      numRecords = rewriter.create<LLVM::MulOp>(
+          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
     }
 
     // Flag word:
@@ -218,40 +258,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     args.push_back(resource);
 
     // Indexing (voffset)
-    Value voffset = createI32Constant(rewriter, loc, 0);
-    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
-      size_t i = pair.index();
-      Value index = pair.value();
-      Value strideOp;
-      if (ShapedType::isDynamic(strides[i])) {
-        strideOp = rewriter.create<LLVM::MulOp>(
-            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
-            byteWidthConst);
-      } else {
-        strideOp =
-            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
-      }
-      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
-      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
-    }
-    if (adaptor.getIndexOffset()) {
-      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
-      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+                                      adaptor.getIndices(), strides);
+    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+        indexOffset && *indexOffset > 0) {
+      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
       voffset =
           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                   : extraOffsetConst;
     }
+    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
     args.push_back(voffset);
 
+    // SGPR offset.
     Value sgprOffset = adaptor.getSgprOffset();
     if (!sgprOffset)
       sgprOffset = createI32Constant(rewriter, loc, 0);
-    if (ShapedType::isDynamic(offset))
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
-    else if (offset > 0)
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
     args.push_back(sgprOffset);
 
     // bit 0: GLC = 0 (atomics drop value, less coherency)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 4c7515dc810516..92ecbff3e691dc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -32,14 +32,16 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
 
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
 func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
-  // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
-  // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+  // CHECK: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
   // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
   // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
-  // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
-  // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
+  // CHECK: %[[tmp:.*]] = llvm.mul %[[size]], %[[stride]] : i64
+  // CHECK: %[[num_elem:.*]] = llvm.trunc %[[tmp]] : i64 to i32
+  // CHECK: %[[numRecords:.*]] = llvm.mul %[[num_elem]], %[[elem_size]] : i32
   // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
``````````
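To make the new addressing arithmetic concrete, here is a worked example with assumed values (illustrative only, not taken from the patch). For a `memref<16x4xi32, strided<[8, 1], offset: ?>>` (4-byte elements) accessed at indices `(i, j)`, the lowering now computes, in outline:

```
// Illustrative arithmetic, assuming memref<16x4xi32, strided<[8, 1], offset: ?>>:
base ptr   = alignedPtr + dynamicOffset * 4  // offset folded into the pointer (GEP), not the SGPR offset
numRecords = max(16 * 8, 4 * 1) * 4 = 512    // bytes; integer max (LLVM::UMaxOp when dynamic), not float MaximumOp
voffset    = (i * 8 + j) * 4                 // linear index built first, scaled by the byte width once
```

Any dynamic sizes or strides in this computation are read from the memref descriptor and truncated or zero-extended to i32 by `convertUnsignedToI32`, which also covers the index bitwidth=i64 case.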
https://github.com/llvm/llvm-project/pull/122293