[Mlir-commits] [mlir] [mlir][amdgpu] Adds make_dma_gather_base (PR #171857)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 11 08:58:53 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Erick Ochoa Lopez (amd-eochoalo)
<details>
<summary>Changes</summary>
* Adds `tdm_gather_base` type.
* Adds `make_gather_dma_base` op.
* Adds `make_gather_dma_base` lowering to ROCDL.
---
Patch is 22.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171857.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+74-7)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+63-44)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+31-7)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir (+52-6)
- (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+26)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..e48614beb542b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -94,6 +94,9 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
let description = [{
This type is opaque and it is used to represent a struct of two addresses.
One address is in LDS while the other is in global memory.
+
+ The value defined by this operation is only intended to be used by
+ amdgpu.tdm_make_descriptor.
}];
let parameters = (ins "Type":$elementType);
let builders = [
@@ -104,6 +107,28 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
let assemblyFormat = "`<` $elementType `>`";
}
+def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> {
+  let summary = "Pair of base addresses that move data between LDS and global storage.";
+  let description = [{
+    This type is opaque and it is used to represent a struct of two addresses.
+    One address is in LDS while the other is in global memory.
+
+    This type is similar to `!amdgpu.tdm_base` but is intended to be used in
+    gather mode. Values of this type are produced by
+    `amdgpu.make_gather_dma_base`.
+
+    The first parameter is the element type of the transferred data and the
+    second parameter is the type of the gather indices, which must be i16 or
+    i32.
+
+    Values of this type are only intended to be used by
+    `amdgpu.make_gather_dma_descriptor`.
+  }];
+  let parameters = (ins "Type":$elementType, "Type":$indexType);
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type":$indexType), [{
+      return $_get(elementType.getContext(), elementType, indexType);
+    }]>
+  ];
+  let assemblyFormat = "`<` $elementType `,` $indexType `>`";
+  let genVerifyDecl = 1;
+}
+
def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
let summary = "Descriptors used in tensor store/load operations.";
let description = [{
@@ -1229,17 +1254,57 @@ def AMDGPU_ScaledMFMAOp :
let hasCanonicalizer = 1;
}
-def AMDGPU_MakeDmaBaseOp :
- AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
+
+// Common definition of the ops that build a pair of base addresses from the
+// `global` memref and the `lds` memref (plus indices into each). Both memrefs
+// must share an element type. `outType` is the result type, e.g.
+// AMDGPU_TDMBaseType or AMDGPU_TDMGatherBaseType.
+class AMDGPU_DmaBaseOp<string mnemonic, Type outType> :
+ AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
Arguments<(ins Arg<AnyMemRef>:$global,
Variadic<Index>:$global_indices,
Arg<AnyMemRef>:$lds,
Variadic<Index>:$lds_indices)>,
- Results<(outs AMDGPU_TDMBaseType: $base)> {
+ Results<(outs outType: $base)> {
// TODO:
// * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
+ let assemblyFormat = [{
+ $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
+ }];
+}
+
+def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> {
+  let summary = "Pair of base addresses used when moving tiles between LDS and global memory.";
+
+  let description = [{
+    This operation creates a pair of addresses that will be used by `tensor_load_to_lds`
+    and `tensor_store_from_lds`.
+
+    This operation creates a value corresponding to the tensor descriptor (D#) group 0
+    found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
+
+    Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>`
+    which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned
+    by `make_gather_dma_descriptor` puts `tensor_load_to_lds` and `tensor_store_from_lds`
+    into gather mode.
+
+    ```mlir
+    %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_gather_base<i32, i16>
+    // %indices : i16
+    %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base<i32, i16>, i16 -> !amdgpu.tdm_descriptor
+    amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+    ```
+
+    These tensor DMA operations were introduced in gfx1250.
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Always true for this op. The shared ROCDL lowering checks this to set
+    /// the gather-mode bits in the first dword of the base.
+    static constexpr bool isGather() {
+      return true;
+    }
+  }];
+}
+
+
+def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> {
+
let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
let description = [{
This operation creates a pair of addresses that will be used by tensor_load_to_lds
@@ -1279,11 +1344,13 @@ def AMDGPU_MakeDmaBaseOp :
These tensor DMA operations were introduced in gfx1250.
}];
- let assemblyFormat = [{
- $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
- }];
-
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ constexpr int64_t isGather() {
+ return false;
+ }
+ }];
}
def AMDGPU_MakeDmaDescriptorOp :
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91154b846f567..6497b970285f3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2276,45 +2276,87 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
}
};
-struct AMDGPUMakeDmaBaseLowering
- : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+/// Shifts `value` left by `shift` (reduced modulo 32) and ORs the result into
+/// `accumulator`.
+///
+/// - A shift of 0 emits no `llvm.shl`.
+/// - When `accumulator` is a constant zero, no `llvm.or` is emitted and the
+///   (possibly shifted) value is returned directly.
+static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+                              Value accumulator, Value value, int64_t shift) {
+  // The shift amount is materialized as an i32 constant, so `value` is
+  // expected to be an i32 word; only the bit position within it is kept.
+  shift = shift % 32;
+  if (shift != 0) {
+    Value shiftAmount = createI32Constant(rewriter, loc, shift);
+    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+  }
+
+  if (matchPattern(accumulator, mlir::m_Zero()))
+    return value;
+
+  return LLVM::OrOp::create(rewriter, loc, accumulator, value);
+}
+
+template <typename BaseOp>
+struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
+ using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
+ using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
- : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+ : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
- matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+ matchAndRewrite(BaseOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx1250)
return op->emitOpError("make_dma_base is only supported on gfx1250");
Location loc = op.getLoc();
+ constexpr int32_t constlen = 4;
+ Value consts[constlen];
+ for (int64_t i = 0; i < constlen; i++)
+ consts[i] = createI32Constant(rewriter, loc, i);
+
+ constexpr int32_t sgprslen = constlen;
+ Value sgprs[sgprslen];
+ for (int64_t i = 0; i < sgprslen; i++) {
+ sgprs[i] = consts[0];
+ }
+
+ sgprs[0] = consts[1];
+
+ if (op.isGather()) {
+ sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
+
+ auto type = cast<TDMGatherBaseType>(op.getResult().getType());
+ Type indexType = type.getIndexType();
+ unsigned indexSize = indexType.getIntOrFloatBitWidth();
+ assert(llvm::is_contained<unsigned>({16, 32}, indexSize) &&
+ "expected index_size to be 16 or 32");
+ unsigned idx = (indexSize / 16) - 1;
+
+ if (idx)
+ sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[idx], 31);
+ }
+
ValueRange ldsIndices = adaptor.getLdsIndices();
Value lds = adaptor.getLds();
auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
- Value ldsPtr =
- getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+ Value ldsPtr = ConvertToLLVMPattern::getStridedElementPtr(
+ rewriter, loc, ldsMemRefType, lds, ldsIndices);
ValueRange globalIndices = adaptor.getGlobalIndices();
Value global = adaptor.getGlobal();
auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
- Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
- global, globalIndices);
+ Value globalPtr = ConvertToLLVMPattern::getStridedElementPtr(
+ rewriter, loc, globalMemRefType, global, globalIndices);
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
- Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+ sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
Value castForGlobalAddr =
LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
- Value lowHalf =
- LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+ sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
createI64Constant(rewriter, loc, 32));
@@ -2322,26 +2364,17 @@ struct AMDGPUMakeDmaBaseLowering
Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
- Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+ highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
- Value typeField = createI32Constant(rewriter, loc, 2 << 30);
- Value highHalfPlusType =
- LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
-
- Value c0 = createI32Constant(rewriter, loc, 0);
- Value c1 = createI32Constant(rewriter, loc, 1);
- Value c2 = createI32Constant(rewriter, loc, 2);
- Value c3 = createI32Constant(rewriter, loc, 3);
+ sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
assert(v4i32 && "expected type conversion to succeed");
Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
- result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
- result = LLVM::InsertElementOp::create(rewriter, loc, result,
- castForLdsAddr, c1);
- result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
- result = LLVM::InsertElementOp::create(rewriter, loc, result,
- highHalfPlusType, c3);
+
+ for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
+ result =
+ LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
rewriter.replaceOp(op, result);
return success();
@@ -2360,21 +2393,6 @@ struct AMDGPUMakeDmaDescriptorLowering
Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
- Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value accumulator, Value value, int64_t shift) const {
- shift = shift % 32;
- Value shiftAmount;
- if (shift != 0) {
- shiftAmount = createI32Constant(rewriter, loc, shift % 32);
- value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
- }
-
- if (matchPattern(accumulator, mlir::m_Zero()))
- return value;
-
- return LLVM::OrOp::create(rewriter, loc, accumulator, value);
- }
-
Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Location loc,
Value sgpr0) const {
@@ -2797,7 +2815,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
- AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
- chipset);
+ AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
+ AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
+ AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..00482ec9d01dd 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -755,28 +755,52 @@ LogicalResult TransposeLoadOp::verify() {
// MakeDmaBaseOp
//===----------------------------------------------------------------------===//
-LogicalResult MakeDmaBaseOp::verify() {
-
- auto ldsType = cast<MemRefType>(getLds().getType());
- auto globalType = cast<MemRefType>(getGlobal().getType());
+/// Verifier shared by `make_dma_base` and `make_gather_dma_base`. It checks:
+/// - the `lds` memref is in workgroup memory;
+/// - the `global` memref is in global memory;
+/// - the element type is 1, 2, 4 or 8 bytes wide.
+template <typename BaseOp>
+static LogicalResult verifyBase(BaseOp op) {
+  auto ldsType = cast<MemRefType>(op.getLds().getType());
+  auto globalType = cast<MemRefType>(op.getGlobal().getType());
if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
+    return op.emitOpError(
"lds memref must have workgroup address space attribute.");
if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
+    return op.emitOpError(
"global memref must have global address space attribute.");
Type elementType = ldsType.getElementType();
+  // `getIntOrFloatBitWidth` is only defined for integer and floating-point
+  // types (memrefs may also hold e.g. index or vector elements), so reject
+  // other element types with a diagnostic instead of asserting.
+  if (!elementType.isIntOrFloat())
+    return op.emitOpError("element type must be an integer or floating-point "
+                          "type but type was ")
+           << elementType << ".";
unsigned width = elementType.getIntOrFloatBitWidth();
if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+    return op.emitOpError(
"element type must be 1, 2, 4, or 8 bytes long but type was ")
<< width << " bits long.";
+  return success();
+}
+
+LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
+
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaBaseOp
+//===----------------------------------------------------------------------===//
+/// Verifies the parameters of `!amdgpu.tdm_gather_base`:
+/// - the element type must be 1, 2, 4 or 8 bytes wide;
+/// - the index type must be i16 or i32.
+LogicalResult
+TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, Type indexType) {
+  // `getIntOrFloatBitWidth` is only defined for integer and floating-point
+  // types, so diagnose anything else before querying the width.
+  if (!elementType.isIntOrFloat())
+    return emitError() << "element type must be an integer or floating-point "
+                          "type but type is "
+                       << elementType << ".";
+  unsigned width = elementType.getIntOrFloatBitWidth();
+  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+    return emitError()
+           << "element type must be 1, 2, 4, or 8 bytes wide but type "
+           << elementType << " is " << width / 8 << " bytes wide.";
+  // Gather indices must be signless 16- or 32-bit integers.
+  if (!indexType.isInteger(16) && !indexType.isInteger(32))
+    return emitError() << "index type must be i16 or i32 but index type is "
+                       << indexType << ".";
return success();
}
+LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
+
//===----------------------------------------------------------------------===//
// MakeDmaDescriptorOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index a94e17ab5b9a5..5dbd38571f630 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -200,6 +200,11 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
// CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
// CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
// CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
// CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
@@ -216,14 +221,10 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
// CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
// CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
- // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
+ // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(30 : i32)
+ // CHECK: %[[TYPE_FIELD:.+]] = llvm.shl %[[C2]], %[[SHIFT]]
// CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
- // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
- // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
-
// CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
// CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
// CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]
@@ -237,6 +238,51 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>
// -----
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// Gather-mode bases: both results OR `shl 1, 30` into the value inserted at
+// element 0 of the vector<4xi32>. Only the i32-index variant additionally
+// ORs in `shl 1, 31`.
+// CHECK-LABEL: func @make_gather_dma_base
+// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
+func.func @make_gather_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_gather_base<i32, i16>, !amdgpu.tdm_gather_base<i32, i32>) {
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+ // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+ // CHECK: %[[SGPR0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+ // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+ // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+ %0 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i16>
+
+ // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK-DAG: %[[GATHER_MODE_OFFSET:.+]] = llvm.mlir.constant(30 : i32) : i32
+ // CHECK-DAG: %[[GATHER_MODE_BIT:.+]] = llvm.shl %[[C1]], %[[GATHER_MODE_OFFSET]]
+ // CHECK: %[[SGPR0_0:.+]] = llvm.or %[[C1]], %[[GATHER_MODE_BIT]]
+
+ // CHECK-DAG: %[[INDEX_SIZE_OFFSET:.+]] = llvm.mlir.constant(31 : i32) : i32
+ // CHECK-DAG: %[[INDEX_SIZE_BIT:.+]] = llvm.shl %[[C1]], %[[INDEX_SIZE_OFFSET]]
+ // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[INDEX_SIZE_BIT]]
+
+ // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
+ // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[SGPR0]], %[[V4I32_0_0]][%[[C0]] : i32]
+
+
+ %1 = amdgpu.make_gather_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<i32, i32>
+
+ func.return %0, %1 : !amdgpu.tdm_gather_base<i32,i16>, !amdgpu.tdm_gather_base<i32,i32>
+}
+
+// -----
+
// CHECK-LABEL: func @make_dma_descriptor
// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor {
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6308ea9a6a096..d3f0f43d039ae 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -367,6 +367,31 @@ func.func @make_dma_base_invalid_addres...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/171857
More information about the Mlir-commits
mailing list