[Mlir-commits] [mlir] [mlir][amdgpu] Add lowering for make_dma_descriptor (PR #169955)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Tue Dec 2 13:19:30 PST 2025
================
@@ -2264,6 +2264,451 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
}
};
+/// Lowers `MakeDmaBaseOp` to a 4 x i32 vector assembled from the element
+/// addresses of the op's source and destination memrefs.
+///
+/// Layout of the produced vector, as built below:
+///   [0] the constant 0
+///   [1] the LDS-side element address converted to i32
+///   [2] bits 0..31 of the global-side address (after masking to 57 bits)
+///   [3] bits 32..56 of the global-side address OR'ed with `typeMask`
+///       (the value 2 placed in bits 30..31)
+///
+/// Emits an op error and fails when the chipset compares below gfx1250.
+struct AMDGPUMakeDmaBaseLowering
+    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250) {
+      return op->emitOpError("make_dma_base is only supported on gfx1250");
+    }
+
+    Location loc = op.getLoc();
+
+    // Resolve the addressed element of the source memref.
+    ValueRange srcIndices = adaptor.getSrcIndices();
+    Value src = adaptor.getSrc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, src, srcIndices);
+
+    // Resolve the addressed element of the destination memref.
+    ValueRange dstIndices = adaptor.getDstIndices();
+    Value dst = adaptor.getDst();
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, dst, dstIndices);
+
+    // Whichever side lives in workgroup memory supplies the LDS address; the
+    // other side supplies the global address.
+    bool storeFrom = hasWorkgroupMemorySpace(srcMemRefType.getMemorySpace());
+    Value ldsAddr = storeFrom ? srcPtr : dstPtr;
+    Value globalAddr = storeFrom ? dstPtr : srcPtr;
+
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+
+    Value castForLdsAddr =
+        LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsAddr);
+    Value castForGlobalAddr =
+        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalAddr);
+
+    // Keep only the low 57 bits of the global address, then split them into
+    // two 32-bit halves.
+    Value mask = createI64Constant(rewriter, loc, 0x1FFFFFFFFFFFFFF);
+    Value first57BitsOfGlobalAddr =
+        LLVM::AndOp::create(rewriter, loc, castForGlobalAddr, mask);
+    Value shift = LLVM::LShrOp::create(rewriter, loc, first57BitsOfGlobalAddr,
+                                       createI64Constant(rewriter, loc, 32));
+
+    Value lowHalf =
+        LLVM::TruncOp::create(rewriter, loc, i32, first57BitsOfGlobalAddr);
+    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
+
+    // Shift in unsigned arithmetic: the result (2^31) is not representable in
+    // a signed int, so `2 << 30` would rely on an out-of-range conversion.
+    Value typeMask =
+        createI32Constant(rewriter, loc, static_cast<int32_t>(2u << 30));
+    Value highHalfPlusType =
+        LLVM::OrOp::create(rewriter, loc, highHalf, typeMask);
+
+    Value c0 = createI32Constant(rewriter, loc, 0);
+    Value c1 = createI32Constant(rewriter, loc, 1);
+    Value c2 = createI32Constant(rewriter, loc, 2);
+    Value c3 = createI32Constant(rewriter, loc, 3);
+
+    // Assemble the four 32-bit words into the result vector.
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    Value result = LLVM::UndefOp::create(rewriter, loc, v4i32);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, c0, c0);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result,
+                                           castForLdsAddr, c1);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result,
+                                           highHalfPlusType, c3);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct AMDGPUMakeDmaDescriptorLowering
+ : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
+ chipset(chipset) {} // Forwards `converter` to the base pattern and stores `chipset`.
+ Chipset chipset; // Set by the constructor above; no use is visible in this excerpt.
+
+ /// Descriptor group 0 is the op's `base` operand, passed through unchanged.
+ Value getDGroup0(OpAdaptor adaptor) const {
+   Value base = adaptor.getBase();
+   return base;
+ }
+
+ /// ORs `value` into `accumulator` after shifting it left by `shift % 32`
+ /// bits. When that remainder is zero no shift is emitted and `value` is
+ /// OR'ed in as-is.
+ Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+                        Value accumulator, Value value, int shift) const {
+   const int bitInWord = shift % 32;
+   Value positioned = value;
+   if (bitInWord != 0) {
+     Value amount = createI32Constant(rewriter, loc, bitInWord);
+     positioned = LLVM::ShlOp::create(rewriter, loc, value, amount);
+   }
+   return LLVM::OrOp::create(rewriter, loc, accumulator, positioned);
+ }
+
+ /// Encodes the element size into `sgpr0` at bit offset 16.
+ ///
+ /// The element width in bytes (1, 2, 4 or 8) selects `consts[0]` ..
+ /// `consts[3]` respectively; any other width is treated as unreachable.
+ /// NOTE(review): presumably `consts[i]` is the i32 constant `i` -- confirm
+ /// where `consts` is built.
+ ///
+ /// `consts` is taken by const reference, matching the sibling helpers, so
+ /// the vector is not copied on every call.
+ Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   Value sgpr0, const SmallVector<Value> &consts) const {
+   // Compute data_size.
+   int elementTypeWidthInBytes = op.getElementTypeWidth() / 8;
+
+   Value dataSize;
+   switch (elementTypeWidthInBytes) {
+   case 1:
+     dataSize = consts[0];
+     break;
+   case 2:
+     dataSize = consts[1];
+     break;
+   case 4:
+     dataSize = consts[2];
+     break;
+   case 8:
+     dataSize = consts[3];
+     break;
+   default:
+     llvm_unreachable("Invalid element size.");
+   }
+   return setValueAtOffset(rewriter, loc, sgpr0, dataSize, 16);
+ }
+
+ /// ORs `consts[1]` into `sgpr0` at bit offset 18 when the op carries an
+ /// atomic barrier address; otherwise returns `sgpr0` untouched.
+ Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter, Location loc,
+                        Value sgpr0, const SmallVector<Value> &consts) const {
+   if (adaptor.getAtomicBarrierAddress())
+     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
+   return sgpr0;
+ }
+
+ /// ORs `consts[1]` into `sgpr0` at bit offset 19 when the op has a global
+ /// increment operand; otherwise returns `sgpr0` untouched.
+ Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter, Location loc,
+                        Value sgpr0, const SmallVector<Value> &consts) const {
+   if (adaptor.getGlobalIncrement())
+     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
+   return sgpr0;
+ }
+
+ /// ORs `consts[1]` into `sgpr0` at bit offset 20 when the op has a pad
+ /// amount operand; otherwise returns `sgpr0` untouched.
+ Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                    ConversionPatternRewriter &rewriter, Location loc,
+                    Value sgpr0, const SmallVector<Value> &consts) const {
+   if (op.getPadAmount())
+     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
+   return sgpr0;
+ }
+
+ /// Encodes the pad interval into `sgpr0` at bit offset 22 as
+ /// `cttz(padInterval) - consts[1]`. Returns `sgpr0` untouched when the op
+ /// has no pad amount operand.
+ Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter, Location loc,
+                      Value sgpr0, const SmallVector<Value> &consts) const {
+   if (!op.getPadAmount())
+     return sgpr0;
+
+   // Expected input: a power of two between 2 and 256.
+   Value trailingZeros = LLVM::CountTrailingZerosOp::create(
+       rewriter, loc, rewriter.getI32Type(), adaptor.getPadInterval(), false);
+   // Expected result after the subtraction: a value between 0 and 7.
+   Value encoded =
+       LLVM::SubOp::create(rewriter, loc, trailingZeros, consts[1]);
+   return setValueAtOffset(rewriter, loc, sgpr0, encoded, 22);
+ }
+
+ /// Encodes the pad amount into `sgpr0` at bit offset 25 as
+ /// `padAmount - consts[1]`. Returns `sgpr0` untouched when the op has no
+ /// pad amount operand.
+ Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                    ConversionPatternRewriter &rewriter, Location loc,
+                    Value sgpr0, const SmallVector<Value> &consts) const {
+   if (!op.getPadAmount())
+     return sgpr0;
+
+   // Expected input: 1-128; expected encoded value: 0-127.
+   Value encoded = LLVM::SubOp::create(rewriter, loc,
+                                       adaptor.getPadAmount(), consts[1]);
+   return setValueAtOffset(rewriter, loc, sgpr0, encoded, 25);
+ }
+
+ /// Encodes the atomic barrier address into `sgpr1` at bit offset 32: the
+ /// pointer is converted to i32, shifted right by `consts[3]`, and masked to
+ /// its low 16 bits. Returns `sgpr1` untouched when the op has no atomic
+ /// barrier address.
+ Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter,
+                               Location loc, Value sgpr1,
+                               const SmallVector<Value> &consts) const {
+   Value barrierPtr = adaptor.getAtomicBarrierAddress();
+   if (!barrierPtr)
+     return sgpr1;
+
+   Value asInteger = LLVM::PtrToIntOp::create(
+       rewriter, loc, rewriter.getI32Type(), barrierPtr);
+   Value shifted = LLVM::LShrOp::create(rewriter, loc, asInteger, consts[3]);
+   Value low16Mask = createI32Constant(rewriter, loc, 0xFFFF);
+   Value field = LLVM::AndOp::create(rewriter, loc, shifted, low16Mask);
+   return setValueAtOffset(rewriter, loc, sgpr1, field, 32);
+ }
+
+ /// Writes the last entry of the op's mixed global sizes across two words:
+ /// the value itself goes into `sgpr1` at bit offset 48, and the value
+ /// shifted right by 16 goes into `sgpr2` at bit offset 64.
+ /// A constant size is materialized as an i32 constant; a dynamic size is
+ /// used as the operand value directly.
+ /// Returns the updated {sgpr1, sgpr2} pair.
+ std::pair<Value, Value>
+ setTensorDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+               ConversionPatternRewriter &rewriter, Location loc, Value sgpr1,
+               Value sgpr2, const SmallVector<Value> &consts) const {
+   OpFoldResult lastSize = op.getMixedGlobalSizes().back();
+
+   Value dim0;
+   if (auto constantSize = dyn_cast<Attribute>(lastSize))
+     dim0 = createI32Constant(rewriter, loc,
+                              cast<IntegerAttr>(constantSize).getInt());
+   else
+     dim0 = cast<Value>(lastSize);
+
+   Value sixteen = createI32Constant(rewriter, loc, 16);
+   Value dim0Upper = LLVM::LShrOp::create(rewriter, loc, dim0, sixteen);
+   Value word1 = setValueAtOffset(rewriter, loc, sgpr1, dim0, 48);
+   Value word2 = setValueAtOffset(rewriter, loc, sgpr2, dim0Upper, 64);
+   return {word1, word2};
+ }
+
+ /// Writes the second-to-last entry of the op's mixed global sizes across
+ /// two words: the value itself goes into `sgpr2` at bit offset 80, and the
+ /// value shifted right by 16 goes into `sgpr3` at bit offset 96.
+ /// A constant size is materialized as an i32 constant; a dynamic size is
+ /// used as the operand value directly.
+ /// Returns the updated {sgpr2, sgpr3} pair.
+ std::pair<Value, Value>
+ setTensorDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+               ConversionPatternRewriter &rewriter, Location loc, Value sgpr2,
+               Value sgpr3, const SmallVector<Value> &consts) const {
+   SmallVector<OpFoldResult> globalSizes = op.getMixedGlobalSizes();
+   OpFoldResult secondToLast = globalSizes[globalSizes.size() - 2];
+
+   Value dim1;
+   if (auto constantSize = dyn_cast<Attribute>(secondToLast))
+     dim1 = createI32Constant(rewriter, loc,
+                              cast<IntegerAttr>(constantSize).getInt());
+   else
+     dim1 = cast<Value>(secondToLast);
+
+   Value sixteen = createI32Constant(rewriter, loc, 16);
+   Value dim1Upper = LLVM::LShrOp::create(rewriter, loc, dim1, sixteen);
+   Value word2 = setValueAtOffset(rewriter, loc, sgpr2, dim1, 80);
+   Value word3 = setValueAtOffset(rewriter, loc, sgpr3, dim1Upper, 96);
+   return {word2, word3};
+ }
+
+ /// ORs one entry of the op's mixed shared sizes into `sgpr` at bit offset
+ /// `offset`. `dimX` counts from the end of the list (0 selects the last
+ /// entry). When the list has `dimX` or fewer entries, `sgpr` is returned
+ /// untouched. A constant size is materialized as an i32 constant; a dynamic
+ /// size is used as the operand value directly.
+ Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   Value sgpr, const SmallVector<Value> &consts, size_t dimX,
+                   int offset) const {
+   SmallVector<OpFoldResult> sharedSizes = op.getMixedSharedSizes();
+   if (dimX >= sharedSizes.size())
+     return sgpr;
+
+   OpFoldResult selected = sharedSizes[sharedSizes.size() - 1 - dimX];
+   Value tileDim;
+   if (auto constantSize = dyn_cast<Attribute>(selected))
+     tileDim = createI32Constant(rewriter, loc,
+                                 cast<IntegerAttr>(constantSize).getInt());
+   else
+     tileDim = cast<Value>(selected);
+
+   return setValueAtOffset(rewriter, loc, sgpr, tileDim, offset);
+ }
+
+ /// Forwards to `setTileDimX` with dimension index 0 and bit offset 112.
+ Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   Value sgpr3, const SmallVector<Value> &consts) const {
+   constexpr size_t kDimIndex = 0;
+   constexpr int kBitOffset = 112;
+   return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, kDimIndex,
+                      kBitOffset);
+ }
+
+ /// Forwards to `setTileDimX` with dimension index 1 and bit offset 128.
+ Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   Value sgpr4, const SmallVector<Value> &consts) const {
+   constexpr size_t kDimIndex = 1;
+   constexpr int kBitOffset = 128;
+   return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, kDimIndex,
+                      kBitOffset);
+ }
+
+ /// Forwards to `setTileDimX` with dimension index 2 and bit offset 144.
+ Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   Value sgpr4, const SmallVector<Value> &consts) const {
+   constexpr size_t kDimIndex = 2;
+   constexpr int kBitOffset = 144;
+   return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, kDimIndex,
+                      kBitOffset);
+ }
+
+ /// Writes one entry of the op's mixed global strides, masked to 48 bits,
+ /// across two consecutive 32-bit words starting at bit offset `offset`.
+ ///
+ /// `dimX` counts from the end of the stride list (0 selects the last
+ /// entry). When the list has `dimX` or fewer entries, both words are
+ /// returned untouched. A constant stride is materialized as an i64
+ /// constant; a dynamic stride is used as the operand value directly.
+ /// NOTE(review): the dynamic path feeds the operand straight into 64-bit
+ /// ops, so it is presumably expected to be i64 already -- confirm.
+ ///
+ /// The first word (`sgprY`) receives the low stride bits at position
+ /// `offset % 32`; it has room for `32 - offset % 32` of them. The remaining
+ /// bits go to the start of the second word (`sgprZ`).
+ /// Returns the updated {sgprY, sgprZ} pair.
+ std::pair<Value, Value>
+ setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Location loc,
+                     Value sgprY, Value sgprZ,
+                     const SmallVector<Value> &consts, size_t dimX,
+                     int offset) const {
+   SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
+
+   if (mixedGlobalStrides.size() <= dimX) {
+     return {sgprY, sgprZ};
+   }
+
+   OpFoldResult tensorDimXStrideOpFoldResult =
+       *(mixedGlobalStrides.rbegin() + dimX);
+   Value tensorDimXStride;
+   if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult)) {
+     tensorDimXStride =
+         createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+   } else {
+     tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
+   }
+
+   constexpr int64_t first48bits = 0xFFFFFFFFFFFF;
+   Value mask = createI64Constant(rewriter, loc, first48bits);
+   tensorDimXStride =
+       LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
+   IntegerType i32 = rewriter.getI32Type();
+   Value tensorDimXStrideLow =
+       LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
+
+   // Number of stride bits that fit in the first word. The previous
+   // `offset % 32` form matched this only for remainders 0 and 16; this form
+   // yields the same values for those offsets and stays correct for others.
+   int shift = 32 - (offset % 32);
+   Value shiftVal = createI64Constant(rewriter, loc, shift);
+   Value tensorDimXStrideHigh =
+       LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
+   tensorDimXStrideHigh =
+       LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
+
+   sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
+   // `offset + shift` is a multiple of 32, so the high part lands at bit 0
+   // of the second word.
+   sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
+                            offset + shift);
+   return {sgprY, sgprZ};
+ }
+
+ /// Forwards to `setTensorDimXStride` with dimension index 0 and bit
+ /// offset 160.
+ std::pair<Value, Value>
+ setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Location loc,
+                     Value sgpr5, Value sgpr6,
+                     const SmallVector<Value> &consts) const {
+   constexpr size_t kDimIndex = 0;
+   constexpr int kBitOffset = 160;
+   return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+                              kDimIndex, kBitOffset);
+ }
+
+ /// Forwards to `setTensorDimXStride` with dimension index 1 and bit
+ /// offset 208.
+ std::pair<Value, Value>
+ setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Location loc,
+                     Value sgpr5, Value sgpr6,
+                     const SmallVector<Value> &consts) const {
+   constexpr size_t kDimIndex = 1;
+   constexpr int kBitOffset = 208;
+   return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+                              kDimIndex, kBitOffset);
+ }
+
+ Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ const SmallVector<Value> &consts) const {
+
+ Value sgpr0, sgpr1, sgpr2, sgpr3, sgpr4, sgpr5, sgpr6, sgpr7;
+ sgpr0 = sgpr1 = sgpr2 = sgpr3 = sgpr4 = sgpr5 = sgpr6 = sgpr7 = consts[0];
+
+ sgpr0 = setDataSize(op, adaptor, rewriter, loc, sgpr0, consts);
----------------
amd-eochoalo wrote:
You mean like this? https://github.com/llvm/llvm-project/pull/169955/commits/f187e76c255c7fede998ce98bc7675df93ba8646 https://github.com/llvm/llvm-project/pull/169955/commits/661931c2ff126b955dc9f16c9a4836b12c3eccd6
https://github.com/llvm/llvm-project/pull/169955
More information about the Mlir-commits
mailing list