[Mlir-commits] [mlir] [mlir][amdgpu] Adds make_dma_gather_base (PR #171857)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 11 08:58:53 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Erick Ochoa Lopez (amd-eochoalo)

<details>
<summary>Changes</summary>

* Adds `tdm_gather_base` type.
* Adds `make_gather_dma_base` op.
* Adds `make_gather_dma_base` lowering to ROCDL.

---

Patch is 22.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171857.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+74-7) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+63-44) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+31-7) 
- (modified) mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir (+52-6) 
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+26) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..e48614beb542b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -94,6 +94,9 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let description = [{
     This type is opaque and it is used to represent a struct of two addresses.
     One address is in LDS while the other is in global memory.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_descriptor.
   }];
   let parameters = (ins "Type":$elementType);
   let builders = [
@@ -104,6 +107,28 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let assemblyFormat = "`<` $elementType `>`";
 }
 
+def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
+  let summary = "Pair of base addresses that move data between LDS and global storage.";
+  let description = [{
+    This type is opaque and it is used to represent a struct of two addresses.
+    One address is in LDS while the other is in global memory.
+
+    This operation is similar to amdgpu.tdm_make_base but intended to be
+    used in gather mode.
+
+    The value defined by this operation is only intended to be used by
+    amdgpu.tdm_make_gather_descriptor.
+  }];
+  let parameters = (ins "Type":$elementType, "Type":$indexType);
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type": $indexType), [{
+      return $_get(elementType.getContext(), elementType, indexType);
+    }]>
+  ];
+  let assemblyFormat = "`<` $elementType `,` $indexType`>`";
+  let genVerifyDecl = 1;
+}
+
 def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
   let summary = "Descriptors used in tensor store/load operations.";
   let description = [{
@@ -1229,17 +1254,57 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
-def AMDGPU_MakeDmaBaseOp :
-    AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
+
+class AMDGPU_DmaBaseOp<string mnemonic, Type outType> :
+    AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
     Arguments<(ins Arg<AnyMemRef>:$global,
                    Variadic<Index>:$global_indices,
                    Arg<AnyMemRef>:$lds,
                    Variadic<Index>:$lds_indices)>,
-    Results<(outs AMDGPU_TDMBaseType: $base)> {
+    Results<(outs outType: $base)> {
 
   // TODO:
   // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
 
+  let assemblyFormat = [{
+    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
+  }];
+}
+
+def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> {
+  let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
+
+  let description = [{
+    This operation creates a pair of addresses that will be used by `tensor_load_to_lds`
+    and `tensor_store_from_lds`.
+
+    This operation creates a value corresponding to the tensor descriptor (D#) group 0
+    found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
+
+    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>`
+    which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned
+    by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode.
+
+    ```mlir
+      %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
+      // %indices : i16
+      %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, i16>, i16 -> !amdgpu.tdm_descriptor
+      amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+    ```
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    constexpr int64_t isGather() {
+      return true;
+    }
+  }];
+}
+
+
+def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> {
+
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
   let description = [{
     This operation creates a pair of addresses that will be used by tensor_load_to_lds
@@ -1279,11 +1344,13 @@ def AMDGPU_MakeDmaBaseOp :
     These tensor DMA operations were introduced in gfx1250.
   }];
 
-  let assemblyFormat = [{
-    $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
-  }];
-
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    constexpr int64_t isGather() {
+      return false;
+    }
+  }];
 }
 
 def AMDGPU_MakeDmaDescriptorOp :
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91154b846f567..6497b970285f3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2276,45 +2276,87 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
-struct AMDGPUMakeDmaBaseLowering
-    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+                              Value accumulator, Value value, int64_t shift) {
+  shift = shift % 32;
+  Value shiftAmount;
+  if (shift != 0) {
+    shiftAmount = createI32Constant(rewriter, loc, shift % 32);
+    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+  }
+
+  if (matchPattern(accumulator, mlir::m_Zero()))
+    return value;
+
+  return LLVM::OrOp::create(rewriter, loc, accumulator, value);
+}
+
+template <typename BaseOp>
+struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
+  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
+  using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
 
   AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult
-  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+  matchAndRewrite(BaseOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (chipset < kGfx1250)
       return op->emitOpError("make_dma_base is only supported on gfx1250");
 
     Location loc = op.getLoc();
 
+    constexpr int32_t constlen = 4;
+    Value consts[constlen];
+    for (int64_t i = 0; i < constlen; i++)
+      consts[i] = createI32Constant(rewriter, loc, i);
+
+    constexpr int32_t sgprslen = constlen;
+    Value sgprs[sgprslen];
+    for (int64_t i = 0; i < sgprslen; i++) {
+      sgprs[i] = consts[0];
+    }
+
+    sgprs[0] = consts[1];
+
+    if (op.isGather()) {
+      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
+
+      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
+      Type indexType = type.getIndexType();
+      unsigned indexSize = indexType.getIntOrFloatBitWidth();
+      assert(llvm::is_contained<unsigned>({16, 32}, indexSize) &&
+             "expected index_size to be 16 or 32");
+      unsigned idx = (indexSize / 16) - 1;
+
+      if (idx)
+        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[idx], 31);
+    }
+
     ValueRange ldsIndices = adaptor.getLdsIndices();
     Value lds = adaptor.getLds();
     auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
 
-    Value ldsPtr =
-        getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+    Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, ldsMemRefType, lds, ldsIndices);
 
     ValueRange globalIndices = adaptor.getGlobalIndices();
     Value global = adaptor.getGlobal();
     auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
 
-    Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
-                                           global, globalIndices);
+    Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, globalMemRefType, global, globalIndices);
 
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
 
-    Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
     Value castForGlobalAddr =
         LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
 
-    Value lowHalf =
-        LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
 
     Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                        createI64Constant(rewriter, loc, 32));
@@ -2322,26 +2364,17 @@ struct AMDGPUMakeDmaBaseLowering
     Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
 
     Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
-    Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
 
-    Value typeField = createI32Constant(rewriter, loc, 2 << 30);
-    Value highHalfPlusType =
-        LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
-
-    Value c0 = createI32Constant(rewriter, loc, 0);
-    Value c1 = createI32Constant(rewriter, loc, 1);
-    Value c2 = createI32Constant(rewriter, loc, 2);
-    Value c3 = createI32Constant(rewriter, loc, 3);
+    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
 
     Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
     assert(v4i32 && "expected type conversion to succeed");
     Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           castForLdsAddr, c1);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
-    result = LLVM::InsertElementOp::create(rewriter, loc, result,
-                                           highHalfPlusType, c3);
+
+    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
+      result =
+          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -2360,21 +2393,6 @@ struct AMDGPUMakeDmaDescriptorLowering
 
   Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
 
-  Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
-                         Value accumulator, Value value, int64_t shift) const {
-    shift = shift % 32;
-    Value shiftAmount;
-    if (shift != 0) {
-      shiftAmount = createI32Constant(rewriter, loc, shift % 32);
-      value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
-    }
-
-    if (matchPattern(accumulator, mlir::m_Zero()))
-      return value;
-
-    return LLVM::OrOp::create(rewriter, loc, accumulator, value);
-  }
-
   Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0) const {
@@ -2797,7 +2815,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
       GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
-      AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
-                                                                  chipset);
+      AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+      AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
+      AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..00482ec9d01dd 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -755,28 +755,52 @@ LogicalResult TransposeLoadOp::verify() {
 // MakeDmaBaseOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult MakeDmaBaseOp::verify() {
-
-  auto ldsType = cast<MemRefType>(getLds().getType());
-  auto globalType = cast<MemRefType>(getGlobal().getType());
+template <typename BaseOp>
+static LogicalResult verifyBase(BaseOp op) {
+  auto ldsType = cast<MemRefType>(op.getLds().getType());
+  auto globalType = cast<MemRefType>(op.getGlobal().getType());
   if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "lds memref must have workgroup address space attribute.");
   if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "global memref must have global address space attribute.");
 
   Type elementType = ldsType.getElementType();
   unsigned width = elementType.getIntOrFloatBitWidth();
 
   if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
-    return emitOpError(
+    return op.emitOpError(
                "element type must be 1, 2, 4, or 8 bytes long but type was ")
            << width << " bits long.";
+  return success();
+}
+
+LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
+
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaBaseOp
+//===----------------------------------------------------------------------===//
 
+LogicalResult
+TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, Type indexType) {
+  unsigned width = elementType.getIntOrFloatBitWidth();
+  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+    return emitError()
+           << "element type must be 1, 2, 4, or 8 bytes wide but type "
+           << elementType << " is " << width / 8 << " bytes wide.";
+  MLIRContext *ctx = elementType.getContext();
+  Type i16 = IntegerType::get(ctx, 32);
+  Type i32 = IntegerType::get(ctx, 16);
+  if (!llvm::is_contained<Type>({i16, i32}, indexType))
+    return emitError() << "index type must be i16 or i32 but index type is "
+                       << indexType << ".";
   return success();
 }
 
+LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
+
 //===----------------------------------------------------------------------===//
 // MakeDmaDescriptorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index a94e17ab5b9a5..5dbd38571f630 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -200,6 +200,11 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
   // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
   // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
 
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
   // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
   // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
 
@@ -216,14 +221,10 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
   // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
   // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
 
-  // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(30 : i32)
+  // CHECK: %[[TYPE_FIELD:.+]] = llvm.shl %[[C2]], %[[SHIFT]]
   // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
 
-  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
-  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
-
   // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
   // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
   // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]
@@ -237,6 +238,51 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
 
 // -----
 
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @make_gather_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
+func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_gather_base<i32, i16>, !amdgpu.tdm_gather_base<i32, i32>) {
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+  // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+  // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+  %0 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i16>
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+  // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+  // CHECK: %[[SGPR0_0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+  // CHECK-DAG: %[[INDEX_SIZE_OFFSET:.+]] = llvm.mlir.constant(31 : i32) : i32
+  // CHECK-DAG: %[[INDEX_SIZE_BIT:.+]] = llvm.shl %[[C1]], %[[INDEX_SIZE_OFFSET]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[INDEX_SIZE_BIT]]
+
+  // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+
+  %1 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i32>
+
+  func.return %0, %1 : !amdgpu.tdm_gather_base<i32,i16>, !amdgpu.tdm_gather_base<i32,i32>
+}
+
+// -----
+
 // CHECK-LABEL: func @make_dma_descriptor
 // CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
 func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor {
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6308ea9a6a096..d3f0f43d039ae 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -367,6 +367,31 @@ func.func @make_dma_base_invalid_addres...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/171857


More information about the Mlir-commits mailing list