[Mlir-commits] [mlir] 5123d36 - [mlir][amdgpu] Lower make_gather_dma_descriptor. (#172083)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 15 10:33:01 PST 2025


Author: Erick Ochoa Lopez
Date: 2025-12-15T13:32:57-05:00
New Revision: 5123d36c021eaf04ce74b6400c59ffff810853b9

URL: https://github.com/llvm/llvm-project/commit/5123d36c021eaf04ce74b6400c59ffff810853b9
DIFF: https://github.com/llvm/llvm-project/commit/5123d36c021eaf04ce74b6400c59ffff810853b9.diff

LOG: [mlir][amdgpu] Lower make_gather_dma_descriptor. (#172083)

* Makes `MakeDescriptorOp` a template for `make_dma_descriptor` and
`make_gather_dma_descriptor`.
* Makes the verifier and folder for `make_dma_descriptor` templates.
* Adds a custom verifier and folder for `make_gather_dma_descriptor`
based on the template.
* Adds `make_gather_dma_descriptor` op.
* Lowers `make_gather_dma_descriptor` to ROCDL.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
    mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
    mlir/test/Dialect/AMDGPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a0b8682965b20..694bb96c85300 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1303,7 +1303,7 @@ def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    constexpr bool isGather() {
+    static constexpr bool isGather() {
       return true;
     }
   }];
@@ -1354,16 +1354,17 @@ def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType>
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    constexpr bool isGather() {
+    static constexpr bool isGather() {
       return false;
     }
   }];
 }
 
-def AMDGPU_MakeDmaDescriptorOp :
-  AMDGPU_Op<"make_dma_descriptor", [Pure, AttrSizedOperandSegments]>,
-  Arguments<(ins
-    AMDGPU_TDMBaseType: $base,
+class AMDGPU_MakeDescriptorOp<string mnemonic> :
+  AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments]>,
+  Results<(outs AMDGPU_TDMDescriptorType: $desc)> {
+
+  dag baseArgs = (ins
     Variadic<Index>: $global_dynamic_sizes,
     DenseI64ArrayAttr: $global_static_sizes,
     Variadic<Index>: $global_dynamic_strides,
@@ -1378,9 +1379,66 @@ def AMDGPU_MakeDmaDescriptorOp :
     Variadic<Index>: $atomic_barrier_indices,
     Optional<Index>: $global_increment,
     Optional<I32>: $lds_increment,
-    Optional<Index>: $iteration_count)>,
-  Results<(outs AMDGPU_TDMDescriptorType: $desc)> {
+    Optional<Index>: $iteration_count);
+
+  code extraClassDeclarationBase = [{
+    int64_t getRank() {
+      return getGlobalStaticSizes().size();
+    }
+
+    unsigned getElementTypeWidth() {
+      return getBase().getType().getElementType().getIntOrFloatBitWidth();
+    }
+
+    SmallVector<OpFoldResult> getMixedGlobalSizes() {
+      return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext());
+    }
+
+    SmallVector<OpFoldResult> getMixedGlobalStrides() {
+      return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext());
+    }
+
+    SmallVector<OpFoldResult> getMixedSharedSizes() {
+      return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext());
+    }
+
+  }];
+
+}
+
+def AMDGPU_MakeGatherDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_gather_dma_descriptor"> {
+  dag args = (ins AMDGPU_TDMGatherBaseType: $base,
+                  AnyTypeOf<[VectorOfMinMaxLengthAndType<1, 8, [I32]>,
+                             VectorOfMinMaxLengthAndType<1, 16, [I16]>]>: $indices);
+  let arguments = !con(args, baseArgs);
+  let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS.";
+
+  let assemblyFormat = [{
+    $base `[` $indices `]`
+    `globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes)
+    `globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides)
+    `sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes)
+    ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )?
+    ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)?
+    ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]`
+                      `:` type($atomic_barrier_address) `)`)?
+    ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )?
+    attr-dict `:` qualified(type($base)) `,` type($indices) `->` type(results)
+  }];
+
+  let hasVerifier = 1;
+  let hasFolder = 1;
+
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    static constexpr bool isGather() {
+      return true;
+    }
+  }];
+}
 
+def AMDGPU_MakeDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_dma_descriptor"> {
+  dag args = (ins AMDGPU_TDMBaseType: $base);
+  let arguments = !con(args, baseArgs);
   let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS.";
   let description = [{
      Make all descriptor groups needed by tensor memory operations.
@@ -1437,30 +1495,15 @@ def AMDGPU_MakeDmaDescriptorOp :
     attr-dict `:` qualified(type($base)) `->` type(results)
   }];
 
-  let extraClassDeclaration = [{
-    int64_t getRank() {
-      return getGlobalStaticSizes().size();
-    }
-
-    unsigned getElementTypeWidth() {
-      return getBase().getType().getElementType().getIntOrFloatBitWidth();
-    }
-
-    SmallVector<OpFoldResult> getMixedGlobalSizes() {
-      return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext());
-    }
-
-    SmallVector<OpFoldResult> getMixedGlobalStrides() {
-      return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext());
-    }
+  let hasVerifier = 1;
+  let hasFolder = 1;
 
-    SmallVector<OpFoldResult> getMixedSharedSizes() {
-      return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext());
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    static constexpr bool isGather() {
+      return false;
     }
   }];
 
-  let hasVerifier = 1;
-  let hasFolder = 1;
 }
 
 #endif // AMDGPU

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 8427ba560c8aa..0fb4837e528be 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -537,6 +537,18 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the number of elements of a vector is greater than
+// or equal to `minLength`.
+class IsVectorOfMinLengthPred<int minLength> :
+  And<[IsVectorOfNonZeroRankTypePred,
+       CPred<"::llvm::cast<::mlir::VectorType>($_self).getNumElements() >= " # minLength>]>;
+
+// Whether the number of elements of a vector is less than
+// or equal to `maxLength`.
+class IsVectorOfMaxLengthPred<int maxLength> :
+  And<[IsVectorOfNonZeroRankTypePred,
+       CPred<"::llvm::cast<::mlir::VectorType>($_self).getNumElements() <= " # maxLength>]>;
+
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedLengths` list
 class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
@@ -600,6 +612,20 @@ class VectorOfLength<list<int> allowedLengths> : Type<
   " of length " # !interleave(allowedLengths, "/"),
   "::mlir::VectorType">;
 
+// Any vector where the number of elements is more than
+// or equal to minLength.
+class VectorOfMinLength<int minLength> : Type<
+  IsVectorOfMinLengthPred<minLength>,
+  " of at least length " # minLength,
+  "::mlir::VectorType">;
+
+// Any vector where the number of elements is less than
+// or equal to maxLength.
+class VectorOfMaxLength<int maxLength> : Type<
+  IsVectorOfMaxLengthPred<maxLength>,
+  " of at most length " # maxLength,
+  "::mlir::VectorType">;
+
 // Any fixed-length vector where the number of elements is from the given
 // `allowedLengths` list
 class FixedVectorOfLength<list<int> allowedLengths> : Type<
@@ -623,6 +649,14 @@ class VectorOfLengthAndType<list<int> allowedLengths,
    VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any vector where the number of elements is between
+// `minLength` and `maxLength` (inclusive).
+class VectorOfMinMaxLengthAndType<int minLength, int maxLength,
+                                  list<Type> allowedTypes> : AllOfType<
+  [VectorOfNonZeroRankOf<allowedTypes>, VectorOfMinLength<minLength>, VectorOfMaxLength<maxLength>],
+   VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfMinLength<minLength>.summary # VectorOfMaxLength<maxLength>.summary,
+  "::mlir::VectorType">;
+
 class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType<
   [elType],
   And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>,

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index af825ad42dd82..40569f6fa544c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2311,18 +2311,18 @@ struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
 
     constexpr int32_t constlen = 4;
     Value consts[constlen];
-    for (int64_t i = 0; i < constlen; i++)
+    for (int64_t i = 0; i < constlen; ++i)
       consts[i] = createI32Constant(rewriter, loc, i);
 
     constexpr int32_t sgprslen = constlen;
     Value sgprs[sgprslen];
-    for (int64_t i = 0; i < sgprslen; i++) {
+    for (int64_t i = 0; i < sgprslen; ++i) {
       sgprs[i] = consts[0];
     }
 
     sgprs[0] = consts[1];
 
-    if (op.isGather()) {
+    if constexpr (BaseOp::isGather()) {
       sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
 
       auto type = cast<TDMGatherBaseType>(op.getResult().getType());
@@ -2382,19 +2382,18 @@ struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
   }
 };
 
-struct AMDGPUMakeDmaDescriptorLowering
-    : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+template <typename DescriptorOp>
+struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
+  using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
+  using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
 
-  AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
-                                  Chipset chipset)
-      : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
-        chipset(chipset) {}
+  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
   Chipset chipset;
 
   Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
 
-  Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0) const {
     Value mask = op.getWorkgroupMask();
@@ -2408,7 +2407,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
   }
 
-  Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
     unsigned elementTypeWidthInBits = op.getElementTypeWidth();
@@ -2419,7 +2418,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
   }
 
-  Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0, ArrayRef<Value> consts) const {
     if (!adaptor.getAtomicBarrierAddress())
@@ -2428,16 +2427,18 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
   }
 
-  Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr0, ArrayRef<Value> consts) const {
     if (!adaptor.getGlobalIncrement())
       return sgpr0;
 
+    // Value is ignored when in gather mode.
+    // TODO: emit error earlier?
     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
   }
 
-  Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
     if (!op.getPadAmount())
@@ -2446,7 +2447,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
   }
 
-  Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptorm,
+  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
     if (!op.getWorkgroupMask())
@@ -2455,7 +2456,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
   }
 
-  Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
     if (!op.getPadAmount())
@@ -2476,7 +2477,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
   }
 
-  Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
     if (!op.getPadAmount())
@@ -2494,7 +2495,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
   }
 
-  Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Location loc, Value sgpr1,
                                 ArrayRef<Value> consts) const {
@@ -2505,9 +2506,9 @@ struct AMDGPUMakeDmaDescriptorLowering
     auto barrierAddressTy =
         cast<MemRefType>(op.getAtomicBarrierAddress().getType());
     ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
-    atomicBarrierAddress =
-        getStridedElementPtr(rewriter, loc, barrierAddressTy,
-                             atomicBarrierAddress, atomicBarrierIndices);
+    atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
+        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
+        atomicBarrierIndices);
     IntegerType i32 = rewriter.getI32Type();
     // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
     // that the 3 LSBs are zero.
@@ -2524,8 +2525,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
   }
 
-  std::pair<Value, Value> setTensorDimX(MakeDmaDescriptorOp op,
-                                        OpAdaptor adaptor,
+  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Value sgpr1, Value sgpr2,
                                         ArrayRef<Value> consts, uint64_t dimX,
@@ -2544,10 +2544,10 @@ struct AMDGPUMakeDmaDescriptorLowering
     // conditions that need to be checked at runtime. This could also be fixed
     // by saying that mixedGlobalSizes is a DynamicI32List.
     Value tensorDimX;
-    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult))
+    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
       tensorDimX =
           createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
-    else {
+    } else {
       IntegerType i32 = rewriter.getI32Type();
       tensorDimX = cast<Value>(tensorDimXOpFoldResult);
       tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
@@ -2561,8 +2561,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return {sgpr1, sgpr2};
   }
 
-  std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
-                                        OpAdaptor adaptor,
+  std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Value sgpr1, Value sgpr2,
                                         ArrayRef<Value> consts) const {
@@ -2570,8 +2569,7 @@ struct AMDGPUMakeDmaDescriptorLowering
                          48);
   }
 
-  std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
-                                        OpAdaptor adaptor,
+  std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Value sgpr2, Value sgpr3,
                                         ArrayRef<Value> consts) const {
@@ -2579,7 +2577,7 @@ struct AMDGPUMakeDmaDescriptorLowering
                          80);
   }
 
-  Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr, ArrayRef<Value> consts, size_t dimX,
                     int64_t offset) const {
@@ -2599,10 +2597,10 @@ struct AMDGPUMakeDmaDescriptorLowering
     // checked at runtime. This could also be fixed by saying that
     // mixedSharedSizes is a DynamicI16List.
     Value tileDimX;
-    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
+    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
       tileDimX =
           createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
-    else {
+    } else {
       IntegerType i32 = rewriter.getI32Type();
       tileDimX = cast<Value>(tileDimXOpFoldResult);
       tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
@@ -2611,26 +2609,50 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
   }
 
-  Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr3, ArrayRef<Value> consts) const {
     return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
   }
 
-  Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr4, ArrayRef<Value> consts) const {
     return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
   }
 
-  Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter, Location loc,
+                        Value sgpr4, ArrayRef<Value> consts) const {
+    auto type = cast<VectorType>(op.getIndices().getType());
+    ArrayRef<int64_t> shape = type.getShape();
+    assert(shape.size() == 1 && "expected shape to be of rank 1.");
+    unsigned length = shape.back();
+    assert(0 < length && length <= 16 && "expected length to be at most 16.");
+    Value value = createI32Constant(rewriter, loc, length);
+    return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
+  }
+
+  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter,
+                                  Location loc, Value sgpr4,
+                                  ArrayRef<Value> consts) const {
+    if constexpr (DescriptorOp::isGather())
+      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
+    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
+  }
+
+  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr4, ArrayRef<Value> consts) const {
+    // Value is ignored when in gather mode.
+    if constexpr (DescriptorOp::isGather())
+      return sgpr4;
     return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
   }
 
   std::pair<Value, Value>
-  setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                       size_t dimX, int64_t offset) const {
@@ -2676,7 +2698,7 @@ struct AMDGPUMakeDmaDescriptorLowering
   }
 
   std::pair<Value, Value>
-  setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
     return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
@@ -2684,18 +2706,21 @@ struct AMDGPUMakeDmaDescriptorLowering
   }
 
   std::pair<Value, Value>
-  setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
+    // Value is ignored when in gather mode.
+    if constexpr (DescriptorOp::isGather())
+      return {sgpr5, sgpr6};
     return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                                1, 208);
   }
 
-  Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    ArrayRef<Value> consts) const {
     Value sgprs[8];
-    for (int64_t i = 0; i < 8; i++) {
+    for (int64_t i = 0; i < 8; ++i) {
       sgprs[i] = consts[0];
     }
 
@@ -2716,7 +2741,8 @@ struct AMDGPUMakeDmaDescriptorLowering
         setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
 
     sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
-    sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
+    sgprs[4] =
+        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
     sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
     std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
         op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
@@ -2736,7 +2762,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return dgroup1;
   }
 
-  Value setTensorDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                       int64_t offset) const {
@@ -2749,10 +2775,10 @@ struct AMDGPUMakeDmaDescriptorLowering
 
     OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
     Value tensorDimX;
-    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult))
+    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
       tensorDimX =
           createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
-    else {
+    } else {
       IntegerType i32 = rewriter.getI32Type();
       tensorDimX = cast<Value>(tensorDimXOpFoldResult);
       tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
@@ -2761,7 +2787,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
   }
 
-  Value setTensorDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
     return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
@@ -2776,7 +2802,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, accumulator, value, shift);
   }
 
-  Value setLDSAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter, Location loc,
                             Value sgpr1, ArrayRef<Value> consts,
                             int64_t offset) const {
@@ -2785,7 +2811,7 @@ struct AMDGPUMakeDmaDescriptorLowering
   }
 
   std::pair<Value, Value>
-  setGlobalAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                          int64_t offset) const {
@@ -2803,8 +2829,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return {sgpr2, sgpr3};
   }
 
-  Value setTensorDim3OrLDSAddrIncrement(MakeDmaDescriptorOp op,
-                                        OpAdaptor adaptor,
+  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Value sgpr1,
                                         ArrayRef<Value> consts) const {
@@ -2819,9 +2844,8 @@ struct AMDGPUMakeDmaDescriptorLowering
   }
 
   std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
-      MakeDmaDescriptorOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter, Location loc, Value sgpr2,
-      Value sgpr3, ArrayRef<Value> consts) const {
+      DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
+      Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
     Value globalIncrement = op.getGlobalIncrement();
     constexpr int32_t dim = 2;
     constexpr int32_t offset = 64;
@@ -2832,7 +2856,7 @@ struct AMDGPUMakeDmaDescriptorLowering
                                   consts, offset);
   }
 
-  Value setIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr3, ArrayRef<Value> consts,
                         int32_t offset) const {
@@ -2850,7 +2874,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
   }
 
-  Value setTileDim3OrIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Location loc, Value sgpr3,
                                   ArrayRef<Value> consts) const {
@@ -2864,9 +2888,17 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
   }
 
-  Value getDGroup2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    ArrayRef<Value> consts) const {
+    if constexpr (DescriptorOp::isGather())
+      return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
+    return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
+  }
+
+  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            ArrayRef<Value> consts) const {
     IntegerType i32 = rewriter.getI32Type();
     Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
     assert(v4i32 && "expected type conversion to succeed.");
@@ -2877,7 +2909,7 @@ struct AMDGPUMakeDmaDescriptorLowering
 
     constexpr int64_t sgprlen = 4;
     Value sgprs[sgprlen];
-    for (int i = 0; i < sgprlen; i++)
+    for (int i = 0; i < sgprlen; ++i)
       sgprs[i] = consts[0];
 
     sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
@@ -2896,8 +2928,78 @@ struct AMDGPUMakeDmaDescriptorLowering
     return dgroup2;
   }
 
+  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
+                         ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<Value> consts, bool firstHalf) const {
+    IntegerType i32 = rewriter.getI32Type();
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    assert(v4i32 && "expected type conversion to succeed.");
+
+    Value indices = adaptor.getIndices();
+    auto vectorType = cast<VectorType>(indices.getType());
+    unsigned length = vectorType.getShape().back();
+    Type elementType = vectorType.getElementType();
+    unsigned maxLength = elementType == i32 ? 4 : 8;
+    int32_t offset = firstHalf ? 0 : maxLength;
+    unsigned discountedLength =
+        std::max(static_cast<int32_t>(length - offset), 0);
+
+    unsigned targetSize = std::min(maxLength, discountedLength);
+
+    SmallVector<Value> indicesVector;
+    for (unsigned i = offset; i < targetSize + offset; ++i) {
+      Value idx;
+      if (i < consts.size())
+        idx = consts[i];
+      else
+        idx = createI32Constant(rewriter, loc, i);
+      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
+      indicesVector.push_back(elem);
+    }
+
+    SmallVector<Value> indicesI32Vector;
+    if (elementType == i32) {
+      indicesI32Vector = indicesVector;
+    } else {
+      for (unsigned i = 0; i < targetSize; ++i) {
+        Value index = indicesVector[i];
+        indicesI32Vector.push_back(
+            LLVM::ZExtOp::create(rewriter, loc, i32, index));
+      }
+      if ((targetSize % 2) != 0)
+        // Add padding when not divisible by two.
+        indicesI32Vector.push_back(consts[0]);
+    }
+
+    SmallVector<Value> indicesToInsert;
+    if (elementType == i32) {
+      indicesToInsert = indicesI32Vector;
+    } else {
+      unsigned size = indicesI32Vector.size() / 2;
+      for (unsigned i = 0; i < size; ++i) {
+        Value first = indicesI32Vector[2 * i];
+        Value second = indicesI32Vector[2 * i + 1];
+        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
+        indicesToInsert.push_back(joined);
+      }
+    }
+
+    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
+      dgroup =
+          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
+
+    return dgroup;
+  }
+
+  Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
+                         ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<Value> consts) const {
+    return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
+  }
+
   std::pair<Value, Value>
-  setTensorDim3Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
     constexpr int32_t dim = 3;
@@ -2906,8 +3008,7 @@ struct AMDGPUMakeDmaDescriptorLowering
                                dim, offset);
   }
 
-  std::pair<Value, Value> setTensorDim4(MakeDmaDescriptorOp op,
-                                        OpAdaptor adaptor,
+  std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Value sgpr1, Value sgpr2,
                                         ArrayRef<Value> consts) const {
@@ -2917,7 +3018,7 @@ struct AMDGPUMakeDmaDescriptorLowering
                          offset);
   }
 
-  Value setTileDim4(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr2, ArrayRef<Value> consts) const {
     constexpr int32_t dim = 4;
@@ -2925,9 +3026,17 @@ struct AMDGPUMakeDmaDescriptorLowering
     return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
   }
 
-  Value getDGroup3(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  /// Builds descriptor group 3, choosing at compile time between the gather
+  /// lowering (packed gather indices) and the non-gather lowering (tensor
+  /// dim 3/4 and tile dim 4 fields).
+  Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    ArrayRef<Value> consts) const {
+    // The explicit `else` turns the non-gather path into a discarded
+    // statement for gather ops, so getDGroup3NonGather (and the setters it
+    // calls) is only instantiated for descriptor ops that can reach it.
+    if constexpr (DescriptorOp::isGather())
+      return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
+    else
+      return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
+  }
+
+  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            ArrayRef<Value> consts) const {
     IntegerType i32 = rewriter.getI32Type();
     Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
     assert(v4i32 && "expected type conversion to succeed.");
@@ -2937,7 +3046,7 @@ struct AMDGPUMakeDmaDescriptorLowering
 
     constexpr int32_t sgprlen = 4;
     Value sgprs[sgprlen];
-    for (int i = 0; i < sgprlen; i++)
+    for (int i = 0; i < sgprlen; ++i)
       sgprs[i] = consts[0];
 
     std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
@@ -2954,8 +3063,14 @@ struct AMDGPUMakeDmaDescriptorLowering
     return dgroup3;
   }
 
+  /// Builds descriptor group 3 in gather mode by delegating to
+  /// getGatherIndices with its trailing flag set to false
+  /// (getDGroup2Gather passes true).
+  Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
+                         ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<Value> consts) const {
+    constexpr bool forDGroup2 = false;
+    return getGatherIndices(op, adaptor, rewriter, loc, consts, forDGroup2);
+  }
+
   LogicalResult
-  matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (chipset < kGfx1250)
       return op->emitOpError(
@@ -2969,7 +3084,7 @@ struct AMDGPUMakeDmaDescriptorLowering
     assert(v4i32 && "expected type conversion to succeed");
 
     SmallVector<Value> consts;
-    for (int64_t i = 0; i < 8; i++)
+    for (int64_t i = 0; i < 8; ++i)
       consts.push_back(createI32Constant(rewriter, loc, i));
 
     Value dgroup0 = this->getDGroup0(adaptor);
@@ -3047,6 +3162,10 @@ void mlir::populateAMDGPUTypeAndAttributeConversions(
     Type i32 = IntegerType::get(type.getContext(), 32);
     return typeConverter.convertType(VectorType::get(4, i32));
   });
+  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
+    Type i32 = IntegerType::get(type.getContext(), 32);
+    return typeConverter.convertType(VectorType::get(4, i32));
+  });
 }
 
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
@@ -3075,6 +3194,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
       AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
       AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
-      AMDGPUMakeDmaDescriptorLowering>(converter, chipset);
+      AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
+      AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b3969e97ff71c..1d3e882adcb67 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -818,50 +818,52 @@ LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
 // MakeDmaDescriptorOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult MakeDmaDescriptorOp::verify() {
-  ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
+/// Verifies the invariants shared by make_dma_descriptor and
+/// make_gather_dma_descriptor:
+///  - global strides are non-empty and the innermost stride is 1;
+///  - the rank is at most 5, and global sizes, global strides and shared
+///    sizes all have the same rank;
+///  - the element type is 8, 16, 32 or 64 bits wide;
+///  - an atomic barrier address, if present, is in workgroup memory (LDS);
+///  - earlyTimeout is only used together with workgroupMask.
+template <typename DescriptorOp>
+static LogicalResult verifyDescriptorOp(DescriptorOp op) {
+  ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
 
   if (globalStaticStrides.empty())
-    return emitOpError("strides must not be empty.");
+    return op.emitOpError("strides must not be empty.");
   if (globalStaticStrides.back() != 1)
-    return emitOpError("strides for the innermost dimension must be 1.");
+    return op.emitOpError("strides for the innermost dimension must be 1.");
 
-  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
+  ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
   size_t rank = globalStaticSizes.size();
   if (rank > 5)
-    return emitOpError("tensor and tile must be at most of rank 5.");
+    return op.emitOpError("tensor and tile must be at most of rank 5.");
   if (rank != globalStaticStrides.size())
-    return emitOpError("strides and sizes must have same rank.");
+    return op.emitOpError("strides and sizes must have same rank.");
 
-  ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
+  ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
   if (rank != sharedStaticSizes.size())
-    return emitOpError("tensor must have same rank as tile.");
+    return op.emitOpError("tensor must have same rank as tile.");
 
-  unsigned elementTypeWidth = getElementTypeWidth();
+  unsigned elementTypeWidth = op.getElementTypeWidth();
   if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
-    return emitOpError(
+    return op.emitOpError(
                "element type width must be 1, 2, 4 or 8 bytes, but was ")
            << elementTypeWidth << " bits long";
 
-  if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
+  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
     auto atomicBarrierAddressType =
         cast<MemRefType>(atomicBarrierAddress.getType());
     bool barrierInLDS =
         hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
     if (!barrierInLDS)
-      return emitOpError("atomic barrier address must be in LDS.");
+      return op.emitOpError("atomic barrier address must be in LDS.");
   }
 
-  if (getEarlyTimeout() && !getWorkgroupMask())
-    return emitOpError(
+  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
+    return op.emitOpError(
         "early timeout does not apply when workgroup_mask is not set.");
   return success();
 }
 
-OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
-  SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes());
-  SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides());
-  SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes());
+template <typename DescriptorOp, typename FoldAdaptor>
+static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
+  SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
+  SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
+  SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
 
   if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
                                   /*onlyNonZero=*/true)) &&
@@ -878,19 +880,49 @@ OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
 
   dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                              staticGlobalSizes);
-  setGlobalStaticSizes(staticGlobalSizes);
-  getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
+  op.setGlobalStaticSizes(staticGlobalSizes);
+  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
 
   dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                              staticGlobalStrides);
-  setGlobalStaticStrides(staticGlobalStrides);
-  getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
+  op.setGlobalStaticStrides(staticGlobalStrides);
+  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
 
   dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                              staticSharedSizes);
-  setSharedStaticSizes(staticSharedSizes);
-  getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
-  return getResult();
+  op.setSharedStaticSizes(staticSharedSizes);
+  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
+  return op.getResult();
+}
+
+/// Verifies make_dma_descriptor using the checks it shares with
+/// make_gather_dma_descriptor.
+LogicalResult MakeDmaDescriptorOp::verify() {
+  return verifyDescriptorOp<MakeDmaDescriptorOp>(*this);
+}
+
+/// Folds make_dma_descriptor through foldDescriptorOp, which moves foldable
+/// dynamic global size/stride and shared size operands into the op's static
+/// lists and returns the op's own result.
+OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
+  return foldDescriptorOp<MakeDmaDescriptorOp>(*this, adaptor);
+}
+
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies make_gather_dma_descriptor: gather mode is limited to rank <= 2,
+/// the indices' element type must equal the base's index type, and every
+/// check shared with make_dma_descriptor must also hold.
+LogicalResult MakeGatherDmaDescriptorOp::verify() {
+  if (getGlobalStaticSizes().size() > 2)
+    return emitOpError(
+        "tensor and tile must be at most of rank two in gather mode.");
+
+  // NOTE(review): the diagnostic mentions the base's "element type", but the
+  // comparison is against the base's index type (getIndexType()).
+  Type indicesElementType =
+      cast<VectorType>(getIndices().getType()).getElementType();
+  Type expectedIndexType = getBase().getType().getIndexType();
+  if (indicesElementType != expectedIndexType)
+    return emitOpError("indices' element type must match base's element type.");
+
+  return verifyDescriptorOp(*this);
+}
+
+/// Folds make_gather_dma_descriptor with the same helper used by
+/// make_dma_descriptor.
+OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
+  return foldDescriptorOp<MakeGatherDmaDescriptorOp>(*this, adaptor);
+}
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index 58014e52fa191..4979e85785970 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -772,3 +772,112 @@ func.func @make_dma_descriptor_workgroup_mask(%base: !amdgpu.tdm_base<i32>, %wg_
   %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] workgroupMask %wg_mask earlyTimeout %timeout : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
   func.return %descriptor : !amdgpu.tdm_descriptor
 }
+
+// -----
+
+// Gather lowering with 13 i16 indices:
+//  * the number of valid indices (13) is written to element 4 of DGROUP1;
+//  * indices 0-7 are zero-extended to i32 and packed pairwise (even index in
+//    the low 16 bits, odd index shifted into the high 16 bits) into the four
+//    i32 lanes of DGROUP2;
+//  * indices 8-12 plus a zero pad are packed the same way into the first
+//    three lanes of DGROUP3.
+// CHECK-LABEL: func @make_gather_dma_descriptor
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_gather_base<i32, i16>, %[[INDICES:.+]]: vector<13xi16>)
+func.func @make_gather_dma_descriptor(%base: !amdgpu.tdm_gather_base<i32, i16>, %indices: vector<13xi16>) -> !amdgpu.tdm_descriptor {
+
+  // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]]
+
+  // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32)
+  // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32)
+  // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32)
+  // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32)
+  // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32)
+  // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32)
+  // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32)
+  // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32)
+
+  // CHECK: %[[SGPR4:.+]] = llvm.mlir.constant([[VALID_INDICES:13]] : i32) : i32
+
+  // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32>
+  // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement {{.*}}, %[[V8I32]][%[[C0]] : i32]
+  // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_0]][%[[C1]] : i32]
+  // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_1]][%[[C2]] : i32]
+  // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_2]][%[[C3]] : i32]
+  // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32]
+  // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_4]][%[[C5]] : i32]
+  // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_5]][%[[C6]] : i32]
+  // CHECK: %[[DGROUP1:.+]] = llvm.insertelement {{.*}}, %[[DGROUP1_6]][%[[C7]] : i32]
+
+  // CHECK-DAG: %[[IDX0:.+]] = llvm.extractelement %[[INDICES]][%[[C0]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX1:.+]] = llvm.extractelement %[[INDICES]][%[[C1]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX2:.+]] = llvm.extractelement %[[INDICES]][%[[C2]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX3:.+]] = llvm.extractelement %[[INDICES]][%[[C3]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX4:.+]] = llvm.extractelement %[[INDICES]][%[[C4]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX5:.+]] = llvm.extractelement %[[INDICES]][%[[C5]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX6:.+]] = llvm.extractelement %[[INDICES]][%[[C6]] : i32] : vector<13xi16>
+  // CHECK-DAG: %[[IDX7:.+]] = llvm.extractelement %[[INDICES]][%[[C7]] : i32] : vector<13xi16>
+
+  // CHECK: %[[IDX0_32:.+]] = llvm.zext %[[IDX0]] : i16 to i32
+  // CHECK: %[[IDX1_32:.+]] = llvm.zext %[[IDX1]] : i16 to i32
+  // CHECK: %[[IDX2_32:.+]] = llvm.zext %[[IDX2]] : i16 to i32
+  // CHECK: %[[IDX3_32:.+]] = llvm.zext %[[IDX3]] : i16 to i32
+  // CHECK: %[[IDX4_32:.+]] = llvm.zext %[[IDX4]] : i16 to i32
+  // CHECK: %[[IDX5_32:.+]] = llvm.zext %[[IDX5]] : i16 to i32
+  // CHECK: %[[IDX6_32:.+]] = llvm.zext %[[IDX6]] : i16 to i32
+  // CHECK: %[[IDX7_32:.+]] = llvm.zext %[[IDX7]] : i16 to i32
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX1:.+]] = llvm.shl %[[IDX1_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or disjoint %[[IDX0_32]], %[[IDX1]]
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX3:.+]] = llvm.shl %[[IDX3_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR1:.+]] = llvm.or disjoint %[[IDX2_32]], %[[IDX3]]
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX5:.+]] = llvm.shl %[[IDX5_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR2:.+]] = llvm.or disjoint %[[IDX4_32]], %[[IDX5]]
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX7:.+]] = llvm.shl %[[IDX7_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR3:.+]] = llvm.or disjoint %[[IDX6_32]], %[[IDX7]]
+
+  // CHECK: %[[DGROUP2_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[DGROUP2_1:.+]] = llvm.insertelement %[[SGPR0]], %[[DGROUP2_0]][%[[C0]] : i32]
+  // CHECK: %[[DGROUP2_2:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP2_1]][%[[C1]] : i32]
+  // CHECK: %[[DGROUP2_3:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP2_2]][%[[C2]] : i32]
+  // CHECK: %[[DGROUP2:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP2_3]][%[[C3]] : i32]
+
+  // CHECK: %[[C8:.+]] = llvm.mlir.constant(8 : i32)
+  // CHECK: %[[IDX8:.+]] = llvm.extractelement %[[INDICES]][%[[C8]] : i32] : vector<13xi16>
+  // CHECK: %[[C9:.+]] = llvm.mlir.constant(9 : i32)
+  // CHECK: %[[IDX9:.+]] = llvm.extractelement %[[INDICES]][%[[C9]] : i32] : vector<13xi16>
+  // CHECK: %[[C10:.+]] = llvm.mlir.constant(10 : i32)
+  // CHECK: %[[IDX10:.+]] = llvm.extractelement %[[INDICES]][%[[C10]] : i32] : vector<13xi16>
+  // CHECK: %[[C11:.+]] = llvm.mlir.constant(11 : i32)
+  // CHECK: %[[IDX11:.+]] = llvm.extractelement %[[INDICES]][%[[C11]] : i32] : vector<13xi16>
+  // CHECK: %[[C12:.+]] = llvm.mlir.constant(12 : i32)
+  // CHECK: %[[IDX12:.+]] = llvm.extractelement %[[INDICES]][%[[C12]] : i32] : vector<13xi16>
+
+  // CHECK: %[[IDX8_32:.+]] = llvm.zext %[[IDX8]] : i16 to i32
+  // CHECK: %[[IDX9_32:.+]] = llvm.zext %[[IDX9]] : i16 to i32
+  // CHECK: %[[IDX10_32:.+]] = llvm.zext %[[IDX10]] : i16 to i32
+  // CHECK: %[[IDX11_32:.+]] = llvm.zext %[[IDX11]] : i16 to i32
+  // CHECK: %[[IDX12_32:.+]] = llvm.zext %[[IDX12]] : i16 to i32
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX9:.+]] = llvm.shl %[[IDX9_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR0:.+]] = llvm.or disjoint %[[IDX8_32]], %[[IDX9]]
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX11:.+]] = llvm.shl %[[IDX11_32]], %[[SHIFT]]
+  // CHECK: %[[SGPR1:.+]] = llvm.or disjoint %[[IDX10_32]], %[[IDX11]]
+
+  // CHECK: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[IDX13:.+]] = llvm.shl %[[C0]], %[[SHIFT]]
+  // CHECK: %[[SGPR2:.+]] = llvm.or disjoint %[[IDX12_32]], %[[IDX13]]
+
+  // CHECK-DAG: %[[DGROUP3_0:.+]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: %[[DGROUP3_1:.+]] = llvm.insertelement %[[SGPR0]], %[[DGROUP3_0]][%[[C0]] : i32]
+  // CHECK: %[[DGROUP3_2:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP3_1]][%[[C1]] : i32]
+  // CHECK: %[[DGROUP3:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP3_2]][%[[C2]] : i32]
+
+  // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]], %[[DGROUP2]], %[[DGROUP3]]
+  %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] : !amdgpu.tdm_gather_base<i32, i16>, vector<13xi16> -> !amdgpu.tdm_descriptor
+  func.return %descriptor : !amdgpu.tdm_descriptor
+}
+

diff  --git a/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
index dcb385384a2b8..06dc20deaf500 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --canonicalize --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @make_dma_descriptor_fold
 // CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[IDX:.+]]: index, %[[I32:.+]]: i32)
@@ -17,3 +17,23 @@ func.func @make_dma_descriptor_fold(%base: !amdgpu.tdm_base<i32>, %idx: index, %
         : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
   func.return %0 : !amdgpu.tdm_descriptor
 }
+
+// -----
+
+// Canonicalization folds the constant %c64 operands of globalSize,
+// globalStride and sharedSize into static values on the gather op.
+// CHECK-LABEL: @make_gather_dma_descriptor_fold
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_gather_base<i32, i32>, %[[IDX:.+]]: index, %[[I32:.+]]: i32, %[[INDICES:.+]]: vector<8xi32>)
+func.func @make_gather_dma_descriptor_fold(%base: !amdgpu.tdm_gather_base<i32, i32>, %idx: index, %i32: i32, %indices: vector<8xi32>) -> !amdgpu.tdm_descriptor {
+  %c64 = arith.constant 64 : index
+
+  // CHECK: amdgpu.make_gather_dma_descriptor %[[BASE]][%[[INDICES]]]
+  %0 = amdgpu.make_gather_dma_descriptor %base [%indices]
+        // CHECK-SAME: globalSize [64, 64]
+        globalSize [%c64, %c64]
+        // CHECK-SAME: globalStride [64, 1]
+        globalStride [%c64, 1]
+        // CHECK-SAME: sharedSize [64, 64]
+        sharedSize [%c64, %c64]
+        iterate %idx, %i32, %idx
+        : !amdgpu.tdm_gather_base<i32, i32>, vector<8xi32> -> !amdgpu.tdm_descriptor
+  func.return %0 : !amdgpu.tdm_descriptor
+}

diff  --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index d3f0f43d039ae..9ece57e9ec6a3 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -442,3 +442,13 @@ func.func @make_dma_descriptor_invalid_shared_and_global_rank(%base: !amdgpu.tdm
   func.return
 }
 
+
+// -----
+
+// CHECK-LABEL: func @make_gather_dma_descriptor_invalid_index_types
+// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_gather_base<i32, i16>, %[[VEC:.+]]: vector<8xi32>)
+func.func @make_gather_dma_descriptor_invalid_index_types(%base: !amdgpu.tdm_gather_base<i32, i16>, %indices: vector<8xi32>) {
+  // The i32 indices do not match the base's i16 index type.
+  // expected-error@+1 {{'amdgpu.make_gather_dma_descriptor' op indices' element type must match base's element type.}}
+  amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2] : !amdgpu.tdm_gather_base<i32, i16>, vector<8xi32> -> !amdgpu.tdm_descriptor
+  func.return
+}


        


More information about the Mlir-commits mailing list