[Mlir-commits] [mlir] 73979c1 - [mlir][amdgpu] Lower amdgpu.make_dma_base (#169817)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 2 10:48:35 PST 2025


Author: Erick Ochoa Lopez
Date: 2025-12-02T13:48:31-05:00
New Revision: 73979c1df9695f281d78ad8e18a7023bcbbceab9

URL: https://github.com/llvm/llvm-project/commit/73979c1df9695f281d78ad8e18a7023bcbbceab9
DIFF: https://github.com/llvm/llvm-project/commit/73979c1df9695f281d78ad8e18a7023bcbbceab9.diff

LOG: [mlir][amdgpu] Lower amdgpu.make_dma_base (#169817)

* Adds lowering for `amdgpu.make_dma_base`

Added: 
    mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Dialect/AMDGPU/invalid.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir

Removed: 
    mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3581b07dc4e3e..16eaf28ddd95b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1229,15 +1229,13 @@ def AMDGPU_ScaledMFMAOp :
 
 def AMDGPU_MakeDmaBaseOp :
     AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
-    Arguments<(ins
-                   Arg<AnyMemRef, "buffer to read from">:$src,
-                   Variadic<Index>:$src_indices,
-                   Arg<AnyMemRef, "buffer to write to">:$dst,
-                   Variadic<Index>:$dst_indices)>,
+    Arguments<(ins Arg<AnyMemRef>:$global,
+                   Variadic<Index>:$global_indices,
+                   Arg<AnyMemRef>:$lds,
+                   Variadic<Index>:$lds_indices)>,
     Results<(outs AMDGPU_TDMBaseType: $base)> {
 
   // TODO:
-  // * Add verifiers such that one of the memrefs is from LDS and the other global.
   // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
 
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
@@ -1251,7 +1249,7 @@ def AMDGPU_MakeDmaBaseOp :
     For example:
 
     ```mlir
-      %base = amdgpu.make_dma_base %src[%idx0], %dst[%idx1] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %base = amdgpu.make_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
       %descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
       amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
     ```
@@ -1259,27 +1257,31 @@ def AMDGPU_MakeDmaBaseOp :
     to
 
     ```mlir
-       // pseudocode
-       %base_0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr)>
-       %base_1 = llvm.insertvalue %global_addr, %base_0[0] : !llvm.struct<(ptr, ptr)>
-       %base_2 = llvm.insertvalue %lds_addr, %base_1[1] : !llvm.struct(ptr, ptr)>
-       // type(%base_2) = !llvm.struct<(ptr, ptr) roughly corresponds to amdgpu.tdm_base<i32>
-
-       // The base will be used when contructing dgroup0
-       // when lowering amdgpu.make_dma_descriptor
-       %dgroup0_0 = llvm.mlir.undef : !llvm.struct<(....)>
-       %dgroup0_1 = llvm.insertvalue %base2, %dgroup0_0 : ....
-
-       // When lowering amdgpu.tensor_load_to_lds
-       rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+      // pseudo-code
+      %global_base = llvm.extractvalue %global_memref[1]
+      %global_address = llvm.getelementptr %global_base[...]  // ptrtoint'd to i64, split into i32 halves
+
+      %lds_base = llvm.extractvalue %lds_memref[1]
+      %lds_address = llvm.getelementptr %lds_base[...]  // ptrtoint'd to i32 as %lds_address_i32
+
+      // %base = [1, lds_address_i32, global_address[31:0], global_address[56:32] | (2 << 30)]
+      %poison = llvm.mlir.poison : vector<4xi32>
+      %v0 = llvm.insertelement %c1_i32, %poison[0] : vector<4xi32>
+      %v1 = llvm.insertelement %lds_address_i32, %v0[1] : vector<4xi32>
+      %v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32>
+      %base = llvm.insertelement %global_address_high_with_type, %v2[3] : vector<4xi32>
+
+      rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
     ```
 
     These tensor DMA operations were introduced in gfx1250.
   }];
 
   let assemblyFormat = [{
-    $src `[` $src_indices `]` `,` $dst `[` $dst_indices `]` attr-dict `:` type($src) `,` type($dst) `->` type(results)
+    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
   }];
+
+  let hasVerifier = 1;
 }
 
 def AMDGPU_MakeDmaDescriptorOp :
@@ -1323,12 +1325,12 @@ def AMDGPU_MakeDmaDescriptorOp :
 
      ```mlir
       // Example of moving a two-dimensional tensor to LDS.
-      %base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
       %descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
       amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
 
       // Example of moving a two dimension tensor to LDS where padding is applied after every integer.
-      %base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
       %descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad pad_every %pad_every) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
       amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
      ```

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..2b6938712dad2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2264,6 +2264,77 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
+struct AMDGPUMakeDmaBaseLowering
+    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("make_dma_base is only supported on gfx1250");
+
+    Location loc = op.getLoc();
+
+    ValueRange ldsIndices = adaptor.getLdsIndices();
+    Value lds = adaptor.getLds();
+    auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
+
+    Value ldsPtr =
+        getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+
+    ValueRange globalIndices = adaptor.getGlobalIndices();
+    Value global = adaptor.getGlobal();
+    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
+
+    Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
+                                           global, globalIndices);
+
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+
+    Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+    Value castForGlobalAddr =
+        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
+
+    Value lowHalf =
+        LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+
+    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
+                                       createI64Constant(rewriter, loc, 32));
+
+    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
+
+    Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
+    Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+
+    Value typeField = createI32Constant(rewriter, loc, 2 << 30);
+    Value highHalfPlusType =
+        LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
+
+    Value c0 = createI32Constant(rewriter, loc, 0);
+    Value c1 = createI32Constant(rewriter, loc, 1);
+    Value c2 = createI32Constant(rewriter, loc, 2);
+    Value c3 = createI32Constant(rewriter, loc, 3);
+
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result,
+                                           castForLdsAddr, c1);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result,
+                                           highHalfPlusType, c3);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
   using Base::Base;
@@ -2278,6 +2349,10 @@ struct ConvertAMDGPUToROCDLPass
 
     RewritePatternSet patterns(ctx);
     LLVMTypeConverter converter(ctx);
+    converter.addConversion([&](TDMBaseType type) -> Type {
+      Type i32 = IntegerType::get(type.getContext(), 32);
+      return converter.convertType(VectorType::get(4, i32));
+    });
     populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
     LLVMConversionTarget target(getContext());
     target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
@@ -2333,6 +2408,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering>(converter, chipset);
+           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering>(converter,
+                                                              chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 93cb9b38a5ecf..8b58c3b1dd182 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -705,6 +705,24 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MakeDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaBaseOp::verify() {
+  MemRefType ldsType = cast<MemRefType>(getLds().getType());
+  MemRefType globalType = cast<MemRefType>(getGlobal().getType());
+  if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) {
+    return emitOpError(
+        "lds memref must have workgroup address space attribute.");
+  }
+  if (!hasGlobalMemorySpace(globalType.getMemorySpace())) {
+    return emitOpError(
+        "global memref must have global address space attribute.");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MakeDmaDescriptorOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
similarity index 81%
rename from mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index d2391140ce056..27daea58f8f92 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -162,3 +162,51 @@ func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
   return %ret0: vector<16xf64>
 }
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+// CHECK-LABEL: func @make_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
+func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) {
+  // CHECK-DAG: %[[INT:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
+  // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
+
+  // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
+  // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
+
+  // CHECK-DAG: %[[MEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[MEM_BASE_PTR]][%[[INT]]]
+  // CHECK-DAG: %[[SMEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[SMEM_BASE_PTR]][%[[INT]]]
+
+  // CHECK-DAG: %[[MEM_INT:.+]] = llvm.ptrtoint %[[MEM_BASE_OFFSET]] : !llvm.ptr<1> to i64
+  // CHECK-DAG: %[[SMEM_INT:.+]] = llvm.ptrtoint %[[SMEM_BASE_OFFSET]] : !llvm.ptr<3> to i32
+
+  // CHECK: %[[MEM_INT_LOW:.+]] = llvm.trunc %[[MEM_INT]] : i64 to i32
+  // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64)
+  // CHECK: %[[SHIFTED_MEM_INT:.+]] = llvm.lshr %[[MEM_INT]], %[[SHIFT]]
+  // CHECK: %[[MEM_INT_HIGH:.+]] = llvm.trunc %[[SHIFTED_MEM_INT]] : i64 to i32
+  // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
+  // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
+
+  // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
+  // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
+  // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]
+  // CHECK: %[[V4I32_0_3:.+]] = llvm.insertelement %[[MEM_INT_LOW]], %[[V4I32_0_2]][%[[C2]] : i32]
+  // CHECK: %[[V4I32_0_4:.+]] = llvm.insertelement %[[MEM_INT_HIGH_TYPE]], %[[V4I32_0_3]][%[[C3]] : i32]
+
+  %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32>
+
+  func.return %0 : !amdgpu.tdm_base<i32>
+}

diff  --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5b3a79d14cb1a..b915bfa324c77 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -357,6 +357,20 @@ func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32x
 
 // -----
 
+func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
+  // expected-error@+1 {{'amdgpu.make_dma_base' op lds memref must have workgroup address space attribute.}}
+  amdgpu.make_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_base<i32>
+}
+
+// -----
+
+func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
+  // expected-error@+1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}}
+  amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+}
+
+// -----
+
 func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {
   // expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}}
   amdgpu.make_dma_descriptor %base globalSize [0] globalStride [1] sharedSize [0] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor

diff  --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 390ad8cb8c1a5..3260bd4a8df9a 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -691,9 +691,6 @@ func.func @memory_counter_wait() {
 func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) {
   // CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
   amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
-
-  // CHECK: amdgpu.make_dma_base %[[SMEM]][%[[IDX]]], %[[MEM]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> -> !amdgpu.tdm_base<i32>
-  amdgpu.make_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> -> !amdgpu.tdm_base<i32>
   func.return
 }
 
@@ -748,3 +745,4 @@ func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8x
 
   func.return
 }
+


        


More information about the Mlir-commits mailing list