[Mlir-commits] [mlir] [mlir][amdgpu] Adds make_dma_gather_base (PR #171857)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Dec 11 07:58:58 PST 2025


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/171857

* Adds `tdm_gather_base` type.
* Adds `make_gather_dma_base` op.
* Adds `make_gather_dma_base` lowering to ROCDL.

>From cf7c396632c7d991fa4d7b613f5e68542629e904 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Dec 2025 11:58:52 -0500
Subject: [PATCH 1/9] [mlir][amdgpu] Add make_gather_dma_base.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  60 ++++++++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 104 ++++++++++--------
 .../Conversion/AMDGPUToROCDL/gfx1250.mlir     |  13 ++-
 3 files changed, 120 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..82398c2c82200 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -94,6 +94,9 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let description = [{
     This type is opaque and it is used to represent a struct of two addresses.
     One address is in LDS while the other is in global memory.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_descriptor.
   }];
   let parameters = (ins "Type":$elementType);
   let builders = [
@@ -104,6 +107,27 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let assemblyFormat = "`<` $elementType `>`";
 }
 
+def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
+  let summary = "Pair of base addresses that move data between LDS and global storage.";
+  let description = [{
+    This type is opaque and it is used to represent a struct of two addresses.
+    One address is in LDS while the other is in global memory.
+
+    This operation is similar to amdgpu.tdm_make_base but intended to be
+    used in gather mode.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_gather_descriptor.
+  }];
+  let parameters = (ins "Type":$elementType, "unsigned":$indexSize);
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "unsigned": $indexSize), [{
+      return $_get(elementType.getContext(), elementType, indexSize);
+    }]>
+  ];
+  let assemblyFormat = "`<` $elementType `,` $indexSize `>`";
+}
+
 def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
   let summary = "Descriptors used in tensor store/load operations.";
   let description = [{
@@ -1229,17 +1253,37 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
-def AMDGPU_MakeDmaBaseOp :
-    AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
+
+class AMDGPU_DmaBaseOp<string mnemonic, Type outType> :
+    AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
     Arguments<(ins Arg<AnyMemRef>:$global,
                    Variadic<Index>:$global_indices,
                    Arg<AnyMemRef>:$lds,
                    Variadic<Index>:$lds_indices)>,
-    Results<(outs AMDGPU_TDMBaseType: $base)> {
+    Results<(outs outType: $base)> {
 
   // TODO:
   // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
 
+  let assemblyFormat = [{
+    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
+  }];
+}
+
+def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> {
+  let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
+  let description = "TODO";
+
+  let extraClassDeclaration = [{
+    constexpr int64_t isGather() {
+      return true;
+    }
+  }];
+}
+
+
+def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> {
+
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
   let description = [{
     This operation creates a pair of addresses that will be used by tensor_load_to_lds
@@ -1279,11 +1323,13 @@ def AMDGPU_MakeDmaBaseOp :
     These tensor DMA operations were introduced in gfx1250.
   }];
 
-  let assemblyFormat = [{
-    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
-  }];
-
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    constexpr int64_t isGather() {
+      return false;
+    }
+  }];
 }
 
 def AMDGPU_MakeDmaDescriptorOp :
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91154b846f567..55c68ba73c19b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2276,45 +2276,84 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
-struct AMDGPUMakeDmaBaseLowering
-    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+                              Value accumulator, Value value, int64_t shift) {
+  shift = shift % 32;
+  Value shiftAmount;
+  if (shift != 0) {
+    shiftAmount = createI32Constant(rewriter, loc, shift % 32);
+    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+  }
+
+  if (matchPattern(accumulator, mlir::m_Zero()))
+    return value;
+
+  return LLVM::OrOp::create(rewriter, loc, accumulator, value);
+}
+
+template <typename BaseOp>
+struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
+  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
+  using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
 
   AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult
-  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+  matchAndRewrite(BaseOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (chipset < kGfx1250)
       return op->emitOpError("make_dma_base is only supported on gfx1250");
 
     Location loc = op.getLoc();
 
+    constexpr int32_t constlen = 4;
+    Value consts[constlen];
+    for (int64_t i = 0; i < constlen; i++)
+      consts[i] = createI32Constant(rewriter, loc, i);
+
+    constexpr int32_t sgprslen = constlen;
+    Value sgprs[sgprslen];
+    for (int64_t i = 0; i < sgprslen; i++) {
+      sgprs[i] = consts[0];
+    }
+
+    sgprs[0] = consts[1];
+
+    if (op.isGather()) {
+      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
+
+      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
+      unsigned indexSize = type.getIndexSize();
+      assert(llvm::is_contained<unsigned>({16, 32}, indexSize) &&
+             "expected index_size to be 16 or 32");
+      unsigned idx = (indexSize / 16) - 1;
+      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[idx], 31);
+    }
+
     ValueRange ldsIndices = adaptor.getLdsIndices();
     Value lds = adaptor.getLds();
     auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
 
-    Value ldsPtr =
-        getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+    Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, ldsMemRefType, lds, ldsIndices);
 
     ValueRange globalIndices = adaptor.getGlobalIndices();
     Value global = adaptor.getGlobal();
     auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
 
-    Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
-                                           global, globalIndices);
+    Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, globalMemRefType, global, globalIndices);
 
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
 
-    Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
     Value castForGlobalAddr =
         LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
 
-    Value lowHalf =
-        LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
 
     Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                        createI64Constant(rewriter, loc, 32));
@@ -2322,26 +2361,17 @@ struct AMDGPUMakeDmaBaseLowering
     Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
 
     Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
-    Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
-
-    Value typeField = createI32Constant(rewriter, loc, 2 << 30);
-    Value highHalfPlusType =
-        LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
+    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
 
-    Value c0 = createI32Constant(rewriter, loc, 0);
-    Value c1 = createI32Constant(rewriter, loc, 1);
-    Value c2 = createI32Constant(rewriter, loc, 2);
-    Value c3 = createI32Constant(rewriter, loc, 3);
+    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
 
     Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
     assert(v4i32 && "expected type conversion to succeed");
     Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           castForLdsAddr, c1);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           highHalfPlusType, c3);
+
+    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
+      result =
+          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -2360,21 +2390,6 @@ struct AMDGPUMakeDmaDescriptorLowering
 
   Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
 
-  Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
-                         Value accumulator, Value value, int64_t shift) const {
-    shift = shift % 32;
-    Value shiftAmount;
-    if (shift != 0) {
-      shiftAmount = createI32Constant(rewriter, loc, shift % 32);
-      value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
-    }
-
-    if (matchPattern(accumulator, mlir::m_Zero()))
-      return value;
-
-    return LLVM::OrOp::create(rewriter, loc, accumulator, value);
-  }
-
   Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0) const {
@@ -2797,7 +2812,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
       GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
-      AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
-                                                                  chipset);
+      AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+      AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
+      AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index a94e17ab5b9a5..5c433b5c42bb9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -200,6 +200,11 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
   // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
   // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
 
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
   // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
   // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
 
@@ -216,14 +221,10 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
   // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
   // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
 
-  // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(30 : i32)
+  // CHECK: %[[TYPE_FIELD:.+]] = llvm.shl %[[C2]], %[[SHIFT]]
   // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
 
-  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
-  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
-
   // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
   // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
   // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]

>From f2ebb8142657a0d2aac1fd377df56f52a5ba1da8 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Dec 2025 16:14:22 -0500
Subject: [PATCH 2/9] Add test

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  4 +-
 .../Conversion/AMDGPUToROCDL/gfx1250.mlir     | 45 +++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 55c68ba73c19b..ff4798961ef73 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2329,7 +2329,9 @@ struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
       assert(llvm::is_contained<unsigned>({16, 32}, indexSize) &&
              "expected index_size to be 16 or 32");
       unsigned idx = (indexSize / 16) - 1;
-      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[idx], 31);
+
+      if (idx)
+        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[idx], 31);
     }
 
     ValueRange ldsIndices = adaptor.getLdsIndices();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index 5c433b5c42bb9..4cbfec52cf0a7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -238,6 +238,51 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
 
 // -----
 
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @make_gather_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
+func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_gather_base<i32,16>, !amdgpu.tdm_gather_base<i32, 32>) {
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+  // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+  // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+  %0 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32,16>
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+  // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+  // CHECK: %[[SGPR0_0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+  // CHECK-DAG: %[[INDEX_SIZE_OFFSET:.+]] = llvm.mlir.constant(31 : i32) : i32
+  // CHECK-DAG: %[[INDEX_SIZE_BIT:.+]] = llvm.shl %[[C1]], %[[INDEX_SIZE_OFFSET]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[INDEX_SIZE_BIT]]
+
+  // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+
+  %1 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32,32>
+
+  func.return %0, %1 : !amdgpu.tdm_gather_base<i32,16>, !amdgpu.tdm_gather_base<i32, 32>
+}
+
+// -----
+
 // CHECK-LABEL: func @make_dma_descriptor
 // CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
 func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor {

>From e86509f75799b7da730563db3cdfb65a5c5bb7a6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 09:54:32 -0500
Subject: [PATCH 3/9] Add description

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 82398c2c82200..67d7f053a177b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1272,7 +1272,24 @@ class AMDGPU_DmaBaseOp<string mnemonic, Type outType> :
 
 def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> {
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
-  let description = "TODO";
+
+  let description = [{
+    This operation creates a pair of addresses that will be used by `tensor_load_to_lds`
+    and `tensor_store_from_lds`.
+
+    This operation creates a value corresponding to the tensor descriptor (D#) group 0
+    found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
+
+    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_size>`
+    which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned
+    by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode.
+
+    ```mlir
+      %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, 16>
+      %descriptor = amdgpu.make_gather_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, 16> -> !amdgpu.tdm_descriptor
+      amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+    ```
+  }];
 
   let extraClassDeclaration = [{
     constexpr int64_t isGather() {

>From 5828b4d2e9c707501e69dd2790fc773fb322ea87 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:22:40 -0500
Subject: [PATCH 4/9] Add verifier to make_gather_dma_base

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 ++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 23 ++++++++++++-------
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 14 +++++++++++
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 67d7f053a177b..fd4cc603f7d5b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1291,6 +1291,8 @@ def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU
     ```
   }];
 
+  let hasVerifier = 1;
+
   let extraClassDeclaration = [{
     constexpr int64_t isGather() {
       return true;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..db3b224da8cd5 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -755,28 +755,35 @@ LogicalResult TransposeLoadOp::verify() {
 // MakeDmaBaseOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult MakeDmaBaseOp::verify() {
-
-  auto ldsType = cast<MemRefType>(getLds().getType());
-  auto globalType = cast<MemRefType>(getGlobal().getType());
+template <typename BaseOp>
+static LogicalResult verifyBase(BaseOp op) {
+  auto ldsType = cast<MemRefType>(op.getLds().getType());
+  auto globalType = cast<MemRefType>(op.getGlobal().getType());
   if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "lds memref must have workgroup address space attribute.");
   if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "global memref must have global address space attribute.");
 
   Type elementType = ldsType.getElementType();
   unsigned width = elementType.getIntOrFloatBitWidth();
 
   if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
-    return emitOpError(
+    return op.emitOpError(
                "element type must be 1, 2, 4, or 8 bytes long but type was ")
            << width << " bits long.";
-
   return success();
 }
 
+LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
+
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
+
 //===----------------------------------------------------------------------===//
 // MakeDmaDescriptorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6308ea9a6a096..4f2ec8e7be8a5 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -371,6 +371,20 @@ func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32,
 
 // -----
 
+func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
+  // expected-error@+1 {{'amdgpu.make_gather_dma_base' op lds memref must have workgroup address space attribute.}}
+  amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, 16>
+}
+
+// -----
+
+func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
+  // expected-error@+1 {{'amdgpu.make_gather_dma_base' op global memref must have global address space attribute.}}
+  amdgpu.make_gather_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, 16>
+}
+
+// -----
+
 func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {
   // expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}}
   amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor

>From 59c12622fe62e424ebfe2427fc230f26e65f880d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:31:45 -0500
Subject: [PATCH 5/9] Change indexSize to indexType

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td       | 10 +++++-----
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp |  3 ++-
 mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir     |  8 ++++----
 mlir/test/Dialect/AMDGPU/invalid.mlir               |  4 ++--
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index fd4cc603f7d5b..ad7f0cf158d8b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -119,13 +119,13 @@ def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
     The value defined by this operation is only intended to be used by
     amdgpu.tdm_make_gather_descriptor.
   }];
-  let parameters = (ins "Type":$elementType, "unsigned":$indexSize);
+  let parameters = (ins "Type":$elementType, "Type":$indexType);
   let builders = [
-    TypeBuilderWithInferredContext<(ins "Type":$elementType, "unsigned": $indexSize), [{
-      return $_get(elementType.getContext(), elementType, indexSize);
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type": $indexType), [{
+      return $_get(elementType.getContext(), elementType, indexType);
     }]>
   ];
-  let assemblyFormat = "`<` $elementType `,` $indexSize `>`";
+  let assemblyFormat = "`<` $elementType `,` $indexType`>`";
 }
 
 def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
@@ -1280,7 +1280,7 @@ def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU
     This operation creates a value corresponding to the tensor descriptor (D#) group 0
     found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
 
-    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_size>`
+    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>`
     which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned
     by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode.
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ff4798961ef73..6497b970285f3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2325,7 +2325,8 @@ struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
       sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
 
       auto type = cast<TDMGatherBaseType>(op.getResult().getType());
-      unsigned indexSize = type.getIndexSize();
+      Type indexType = type.getIndexType();
+      unsigned indexSize = indexType.getIntOrFloatBitWidth();
       assert(llvm::is_contained<unsigned>({16, 32}, indexSize) &&
              "expected index_size to be 16 or 32");
       unsigned idx = (indexSize / 16) - 1;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index 4cbfec52cf0a7..5dbd38571f630 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -243,7 +243,7 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
 
 // CHECK-LABEL: func @make_gather_dma_base
 // CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
-func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_gather_base<i32,16>, !amdgpu.tdm_gather_base<i32, 32>) {
+func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_gather_base<i32, i16>, !amdgpu.tdm_gather_base<i32, i32>) {
 
   // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
@@ -257,7 +257,7 @@ func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_add
   // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
   // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
 
-  %0 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32,16>
+  %0 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i16>
 
   // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
@@ -276,9 +276,9 @@ func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_add
   // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
 
 
-  %1 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32,32>
+  %1 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i32>
 
-  func.return %0, %1 : !amdgpu.tdm_gather_base<i32,16>, !amdgpu.tdm_gather_base<i32, 32>
+  func.return %0, %1 : !amdgpu.tdm_gather_base<i32,i16>, !amdgpu.tdm_gather_base<i32,i32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 4f2ec8e7be8a5..2f6906cc64ee7 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -373,14 +373,14 @@ func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32,
 
 func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
   // expected-error at +1 {{'amdgpu.make_gather_dma_base' op lds memref must have workgroup address space attribute.}}
-  amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, 16>
+  amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i16>
 }
 
 // -----
 
 func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
   // expected-error at +1 {{'amdgpu.make_gather_dma_base' op global memref must have global address space attribute.}}
-  amdgpu.make_gather_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, 16>
+  amdgpu.make_gather_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
 }
 
 // -----

>From d80d2dfbc8c083055e3512c201ddb0ea8c365383 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:47:37 -0500
Subject: [PATCH 6/9] Add verification to TDMGatherBaseType

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  1 +
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 17 +++++++++++++++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         |  7 +++++++
 3 files changed, 25 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ad7f0cf158d8b..d479976d16ad5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -126,6 +126,7 @@ def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
     }]>
   ];
   let assemblyFormat = "`<` $elementType `,` $indexType`>`";
+  let genVerifyDecl = 1;
 }
 
 def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index db3b224da8cd5..00482ec9d01dd 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -782,6 +782,23 @@ LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
 // MakeGatherDmaBaseOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, Type indexType) {
+  unsigned width = elementType.getIntOrFloatBitWidth();
+  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+    return emitError()
+           << "element type must be 1, 2, 4, or 8 bytes wide but type "
+           << elementType << " is " << width / 8 << " bytes wide.";
+  MLIRContext *ctx = elementType.getContext();
+  Type i16 = IntegerType::get(ctx, 16); // Only i16 and i32 index types pass.
+  Type i32 = IntegerType::get(ctx, 32);
+  if (!llvm::is_contained<Type>({i16, i32}, indexType))
+    return emitError() << "index type must be i16 or i32 but index type is "
+                       << indexType << ".";
+  return success();
+}
+
 LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 2f6906cc64ee7..a2d60c9cedea1 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -378,6 +378,13 @@ func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8
 
 // -----
 
+func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
+  // expected-error at +1 {{index type must be i16 or i32 but index type is 'i64'.}}
+  amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i64>
+}
+
+// -----
+
 func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
   // expected-error at +1 {{'amdgpu.make_gather_dma_base' op global memref must have global address space attribute.}}
   amdgpu.make_gather_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>

>From f810a416284b364c92be1d9b805a410659a00f33 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:49:33 -0500
Subject: [PATCH 7/9] documentation

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d479976d16ad5..e48614beb542b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1286,8 +1286,9 @@ def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU
     by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode.
 
     ```mlir
-      %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, 16>
-      %descriptor = amdgpu.make_gather_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, 16> -> !amdgpu.tdm_descriptor
+      %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
+      // %indices : i16
+      %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, i16>, i16 -> !amdgpu.tdm_descriptor
       amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
     ```
   }];

>From c34761b548754a76e2e51dae17998be4b327a6d0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:53:50 -0500
Subject: [PATCH 8/9] Change test name

---
 mlir/test/Dialect/AMDGPU/invalid.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index a2d60c9cedea1..a816553fbf951 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -378,7 +378,7 @@ func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8
 
 // -----
 
-func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
+func.func @make_gather_dma_base_invalid_index_type(%idx: index, %mem: memref<8xi32>) {
   // expected-error at +1 {{index type must be i16 or i32 but index type is 'i64'.}}
   amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i64>
 }

>From b3292986df53a1e6915a00e82f9c3c6bc953d01d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Dec 2025 10:57:49 -0500
Subject: [PATCH 9/9] Correct tests

---
 mlir/test/Dialect/AMDGPU/invalid.mlir | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index a816553fbf951..d3f0f43d039ae 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -367,6 +367,7 @@ func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>)
 func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
   // expected-error at +1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}}
   amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+  return
 }
 
 // -----
@@ -374,13 +375,15 @@ func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32,
 func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
   // expected-error at +1 {{'amdgpu.make_gather_dma_base' op lds memref must have workgroup address space attribute.}}
   amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i16>
+  return
 }
 
 // -----
 
-func.func @make_gather_dma_base_invalid_index_type(%idx: index, %mem: memref<8xi32>) {
+func.func @make_gather_dma_base_invalid_index_type(%idx: index, %smem: memref<8xi32, #gpu.address_space<workgroup>>, %mem: memref<8xi32>) {
   // expected-error at +1 {{index type must be i16 or i32 but index type is 'i64'.}}
-  amdgpu.make_gather_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i64>
+  amdgpu.make_gather_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> -> !amdgpu.tdm_gather_base<i32, i64>
+  return
 }
 
 // -----
@@ -388,6 +391,7 @@ func.func @make_gather_dma_base_invalid_index_type(%idx: index, %mem: memref<8xi
 func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
   // expected-error at +1 {{'amdgpu.make_gather_dma_base' op global memref must have global address space attribute.}}
   amdgpu.make_gather_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
+  return
 }
 
 // -----
@@ -395,6 +399,7 @@ func.func @make_gather_dma_base_invalid_addressspace(%idx: index, %smem : memref
 func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {
   // expected-error at +1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}}
   amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+  return
 }
 
 // -----



More information about the Mlir-commits mailing list