[Mlir-commits] [mlir] [mlir][amdgpu] Continue lowering make_tdm_descriptor. (PR #171498)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Dec 10 11:00:38 PST 2025


================
@@ -2680,16 +2717,228 @@ struct AMDGPUMakeDmaDescriptorLowering
     return dgroup1;
   }
 
+  /// Packs the size of global tensor dimension `dimX` into `sgpr0`.
+  ///
+  /// The op's static and dynamic global sizes are merged into one list that
+  /// is indexed from the back, so `dimX == 0` selects the last size.
+  ///
+  /// \param sgpr0  Word that receives the size.
+  /// \param dimX   Dimension index, counted from the last global size.
+  /// \param offset Offset forwarded unchanged to `setValueAtOffset`.
+  /// \returns `sgpr0` untouched when the op has `dimX` or fewer global
+  ///          sizes, otherwise the result of `setValueAtOffset`.
+  ///
+  /// `op` and `consts` are not read by this function.
+  Value setTensorDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter, Location loc,
+                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
+                      int64_t offset) const {
+    SmallVector<OpFoldResult> mixedGlobalSizes =
+        getMixedValues(adaptor.getGlobalStaticSizes(),
+                       adaptor.getGlobalDynamicSizes(), rewriter);
+    // Compare using the container's own size type; `unsigned long` is only
+    // 32 bits wide on LLP64 targets.
+    if (mixedGlobalSizes.size() <= static_cast<size_t>(dimX))
+      return sgpr0;
+
+    OpFoldResult size = *(mixedGlobalSizes.rbegin() + dimX);
+    Value tensorDimX;
+    if (auto attr = dyn_cast<Attribute>(size)) {
+      // Static size: materialize it directly as an i32 constant.
+      tensorDimX =
+          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+    } else {
+      // Dynamic size: narrow the SSA value to i32.
+      tensorDimX = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
+                                         cast<Value>(size));
+    }
+
+    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
+  }
+
+  /// Stores global tensor dimension 2 in `sgpr0` by forwarding to
+  /// `setTensorDimX` with dimension 2 and offset 0.
+  Value setTensorDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter, Location loc,
+                      Value sgpr0, ArrayRef<Value> consts) const {
+    constexpr int64_t tensorDim = 2;
+    constexpr int64_t fieldOffset = 0;
+    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, tensorDim,
+                         fieldOffset);
+  }
+
+  /// Truncates `value` to i32 and hands the result to `setValueAtOffset`
+  /// together with `accumulator` and `shift`.
+  ///
+  /// \returns whatever `setValueAtOffset` returns.
+  Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value accumulator,
+                                    Value value, int64_t shift) const {
+    Value truncated =
+        LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), value);
+    return setValueAtOffset(rewriter, loc, accumulator, truncated, shift);
+  }
+
+  /// Writes the converted LDS increment operand into `sgpr1` through
+  /// `setValueAtOffset` at `offset`. `op` and `consts` are not read.
+  Value setLDSAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            Value sgpr1, ArrayRef<Value> consts,
+                            int64_t offset) const {
+    return setValueAtOffset(rewriter, loc, sgpr1, adaptor.getLdsIncrement(),
+                            offset);
+  }
+
+  /// Splits the global address increment across `sgpr2` and `sgpr3`.
+  ///
+  /// The increment is truncated to i32 and written into `sgpr2` at `offset`.
+  /// The increment shifted right by 32 is then truncated and written into
+  /// `sgpr3` at `offset + 32`, and the resulting word is ANDed with 0xFFFF.
+  ///
+  /// NOTE(review): the AND is applied to the word returned by
+  /// `setValueAtOffset`, not to the inserted value alone, so it presumably
+  /// also clears any bits above bit 15 that `sgpr3` already held. Confirm
+  /// callers always pass an `sgpr3` whose upper bits are zero.
+  ///
+  /// \returns the updated {sgpr2, sgpr3} pair. `op` and `consts` are not read.
+  std::pair<Value, Value>
+  setGlobalAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                         ConversionPatternRewriter &rewriter, Location loc,
+                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
+                         int64_t offset) const {
+    Value increment = adaptor.getGlobalIncrement();
+
+    // Lower part of the increment goes into the first word.
+    Value lowWord =
+        truncateAndSetValueAtOffset(rewriter, loc, sgpr2, increment, offset);
+
+    // Move bits [32, 64) down so the truncation keeps them.
+    Value thirtyTwo = createI64Constant(rewriter, loc, 32);
+    Value upperBits =
+        LLVM::LShrOp::create(rewriter, loc, increment, thirtyTwo);
+    Value highWord = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
+                                                 upperBits, offset + 32);
+
+    // Keep only the low 16 bits of the second word.
+    constexpr int64_t low16BitsMask = (1ll << 16) - 1;
+    Value mask = createI32Constant(rewriter, loc, low16BitsMask);
+    highWord = LLVM::AndOp::create(rewriter, loc, highWord, mask);
+    return {lowWord, highWord};
+  }
+
+  /// Fills `sgpr1` at offset 32 with one of two fields: the LDS increment
+  /// when `op.getLdsIncrement()` yields a non-null value, otherwise global
+  /// tensor dimension 3 via `setTensorDimX`.
+  Value setTensorDim3OrLDSAddrIncrement(MakeDmaDescriptorOp op,
+                                        OpAdaptor adaptor,
+                                        ConversionPatternRewriter &rewriter,
+                                        Location loc, Value sgpr1,
+                                        ArrayRef<Value> consts) const {
+    constexpr int64_t tensorDim = 3;
+    constexpr int64_t fieldOffset = 32;
+    if (op.getLdsIncrement())
+      return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
+                                 fieldOffset);
+    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, tensorDim,
+                         fieldOffset);
+  }
+
+  /// Fills `sgpr2`/`sgpr3` starting at offset 64 with one of two fields: the
+  /// global address increment when `op.getGlobalIncrement()` yields a
+  /// non-null value, otherwise the result of `setTensorDimXStride` for
+  /// dimension 2.
+  ///
+  /// \returns the updated {sgpr2, sgpr3} pair.
+  std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
+      MakeDmaDescriptorOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter, Location loc, Value sgpr2,
+      Value sgpr3, ArrayRef<Value> consts) const {
+    constexpr int32_t strideDim = 2;
+    constexpr int32_t fieldOffset = 64;
+    if (op.getGlobalIncrement())
+      return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
+                                    consts, fieldOffset);
+    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
+                               consts, strideDim, fieldOffset);
+  }
+
+  /// Encodes the iteration count into `sgpr3` at `offset`.
+  ///
+  /// The count is truncated to i32 and `consts[1]` is subtracted from it
+  /// before the result is handed to `setValueAtOffset`.
+  ///
+  /// Pre-condition: the iteration count lies in the inclusive interval
+  /// [1, 256].
+  /// TODO: validate the pre-condition. When it does not hold, the higher
+  /// bits may be affected. In a following PR, add a flag that instruments
+  /// conditions that need to be checked at runtime.
+  Value setIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                        ConversionPatternRewriter &rewriter, Location loc,
+                        Value sgpr3, ArrayRef<Value> consts,
+                        int32_t offset) const {
+    Value truncated = LLVM::TruncOp::create(
+        rewriter, loc, rewriter.getI32Type(), adaptor.getIterationCount());
+    Value encoded = LLVM::SubOp::create(rewriter, loc, truncated, consts[1]);
+    return setValueAtOffset(rewriter, loc, sgpr3, encoded, offset);
+  }
+
+  /// Fills `sgpr3` at offset 112 with one of two fields: the iteration count
+  /// when `op.getIterationCount()` yields a non-null value, otherwise tile
+  /// dimension 3 via `setTileDimX`.
+  ///
+  /// The tile dimension passed to `setTileDimX` is 3, matching this
+  /// function's name and the sibling setters, which each pass the dimension
+  /// they are named after (`setTensorDim3OrLDSAddrIncrement` passes 3,
+  /// `setTileDim4` passes 4).
+  Value setTileDim3OrIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter,
+                                  Location loc, Value sgpr3,
+                                  ArrayRef<Value> consts) const {
+    constexpr int32_t dim = 3;
+    constexpr int32_t offset = 112;
+    if (op.getIterationCount())
+      return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts,
+                             offset);
+
+    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim, offset);
+  }
+
+  /// Builds descriptor group 2 as a 4 x i32 vector.
+  ///
+  /// When the op has no LDS increment and its rank is 2, a poison vector is
+  /// returned. Otherwise four words, each starting out as `consts[0]`, are
+  /// filled by the field setters and inserted into a poison vector at the
+  /// positions given by the leading entries of `consts`.
+  ///
+  /// NOTE(review): `op.getLdsIncrement()` is used here purely as a presence
+  /// test; confirm that this is the intended condition for skipping the
+  /// group.
+  Value getDGroup2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   ArrayRef<Value> consts) const {
+    IntegerType i32 = rewriter.getI32Type();
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    assert(v4i32 && "expected type conversion to succeed.");
+
+    const bool groupUnused = !op.getLdsIncrement() && op.getRank() == 2;
+    if (groupUnused)
+      return LLVM::PoisonOp::create(rewriter, loc, v4i32);
+
+    constexpr int64_t numWords = 4;
+    Value words[numWords];
+    for (Value &word : words)
+      word = consts[0];
+
+    words[0] = setTensorDim2(op, adaptor, rewriter, loc, words[0], consts);
+    words[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
+                                               words[1], consts);
+    // The 64-bit-wide field lands in words 2 and 3; word 3 is then extended
+    // with the tile-dimension / iteration-count field.
+    auto [word2, word3] = setTensorDim2StrideOrGlobalAddrIncrement(
+        op, adaptor, rewriter, loc, words[2], words[3], consts);
+    words[2] = word2;
+    words[3] =
+        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, word3, consts);
+
+    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+    for (auto [word, index] : llvm::zip(words, consts))
+      dgroup2 =
+          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, word, index);
+
+    return dgroup2;
+  }
+
+  /// Forwards `sgpr0`/`sgpr1` to `setTensorDimXStride` for dimension 3 at
+  /// offset 0.
+  ///
+  /// \returns the pair produced by `setTensorDimXStride`.
+  std::pair<Value, Value>
+  setTensorDim3Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter, Location loc,
+                      Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
+    constexpr int32_t strideDim = 3;
+    constexpr int32_t fieldOffset = 0;
+    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
+                               strideDim, fieldOffset);
+  }
+
+  /// Forwards `sgpr1`/`sgpr2` to `setTensorDimX` for global tensor dimension
+  /// 4 at offset 48.
+  ///
+  /// NOTE(review): this uses a `setTensorDimX` form that takes two words and
+  /// returns a pair; that overload is presumably declared elsewhere in the
+  /// file, so confirm it exists.
+  std::pair<Value, Value> setTensorDim4(MakeDmaDescriptorOp op,
+                                        OpAdaptor adaptor,
+                                        ConversionPatternRewriter &rewriter,
+                                        Location loc, Value sgpr1, Value sgpr2,
+                                        ArrayRef<Value> consts) const {
+    constexpr int32_t tensorDim = 4;
+    constexpr int32_t fieldOffset = 48;
+    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts,
+                         tensorDim, fieldOffset);
+  }
+
+  /// Stores tile dimension 4 in `sgpr2` by forwarding to `setTileDimX` with
+  /// dimension 4 and offset 80.
+  Value setTileDim4(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                    ConversionPatternRewriter &rewriter, Location loc,
+                    Value sgpr2, ArrayRef<Value> consts) const {
+    constexpr int32_t tileDim = 4;
+    constexpr int32_t fieldOffset = 80;
+    return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, tileDim,
+                       fieldOffset);
+  }
+
+  Value getDGroup3(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter, Location loc,
+                   ArrayRef<Value> consts) const {
+    IntegerType i32 = rewriter.getI32Type();
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    assert(v4i32 && "expected type conversion to succeed.");
+    if (!op.getLdsIncrement() && op.getRank() == 2)
----------------
amd-eochoalo wrote:

Thanks for pointing this out. I think I need to change the condition: `getLdsIncrement()` returns the value, which may be zero, whereas what I need to check is whether the attribute is present at all.

https://github.com/llvm/llvm-project/pull/171498


More information about the Mlir-commits mailing list