[Mlir-commits] [mlir] [mlir][amdgpu] Continue lowering make_tdm_descriptor. (PR #171498)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Dec 10 11:00:38 PST 2025
================
@@ -2680,16 +2717,228 @@ struct AMDGPUMakeDmaDescriptorLowering
return dgroup1;
}
+ /// Writes the global tensor size of dimension `dimX` into `sgpr0` at bit
+ /// `offset` (via setValueAtOffset). `dimX` counts from the LAST entry of the
+ /// mixed global sizes. A static size becomes an i32 constant; a dynamic size
+ /// is truncated to i32. If there are not enough sizes to have dimension
+ /// `dimX`, `sgpr0` is returned unchanged.
+ Value setTensorDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
+ int64_t offset) const {
+   ArrayRef<int64_t> staticSizes = adaptor.getGlobalStaticSizes();
+   ValueRange dynamicSizes = adaptor.getGlobalDynamicSizes();
+   SmallVector<OpFoldResult> sizes =
+       getMixedValues(staticSizes, dynamicSizes, rewriter);
+   size_t rank = sizes.size();
+   if (rank <= static_cast<unsigned long>(dimX))
+     return sgpr0;
+
+   // Dimension 0 is the last entry, dimension 1 the one before it, etc.
+   OpFoldResult size = sizes[rank - 1 - dimX];
+   Value sizeI32;
+   if (auto dynamicSize = dyn_cast<Value>(size)) {
+     sizeI32 = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
+                                     dynamicSize);
+   } else {
+     auto staticSize = cast<IntegerAttr>(cast<Attribute>(size));
+     sizeI32 = createI32Constant(rewriter, loc, staticSize.getInt());
+   }
+   return setValueAtOffset(rewriter, loc, sgpr0, sizeI32, offset);
+ }
+
+ /// Writes global tensor dimension 2 into `sgpr0` at bit offset 0.
+ Value setTensorDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+   constexpr int64_t dim = 2;
+   constexpr int64_t offset = 0;
+   return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, dim,
+                        offset);
+ }
+
+ /// Truncates `value` to i32 and passes it to setValueAtOffset together with
+ /// `accumulator` and `shift`, returning the updated accumulator.
+ Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
+ Location loc, Value accumulator,
+ Value value, int64_t shift) const {
+   Value narrowed =
+       LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), value);
+   return setValueAtOffset(rewriter, loc, accumulator, narrowed, shift);
+ }
+
+ /// Writes the op's LDS address increment operand into `sgpr1` at `offset`.
+ /// NOTE(review): unlike the global increment, no truncation is applied here;
+ /// this presumably relies on the converted operand already being i32 — confirm.
+ Value setLDSAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr1, ArrayRef<Value> consts,
+ int64_t offset) const {
+   return setValueAtOffset(rewriter, loc, sgpr1, adaptor.getLdsIncrement(),
+                           offset);
+ }
+
+ /// Packs the global address increment into two accumulators.
+ /// The value truncated to i32 is placed at `offset` in `sgpr2`. The next
+ /// 16 bits (bits [32, 48) of the increment) are placed at `offset + 32` in
+ /// `sgpr3`. Only the increment's own bits are masked, so any bits already
+ /// present in `sgpr3` outside that field are preserved rather than cleared.
+ /// NOTE(review): assumes the converted global increment is an i64 value (it is
+ /// shifted by an i64 constant) — confirm.
+ std::pair<Value, Value>
+ setGlobalAddrIncrement(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
+ int64_t offset) const {
+   Value globalAddrIncrement = adaptor.getGlobalIncrement();
+   // Low 32 bits -> sgpr2.
+   sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
+                                       globalAddrIncrement, offset);
+   // High part -> sgpr3. Mask the value itself (not the accumulator) so that
+   // fields already living in the upper half of sgpr3 are not wiped out.
+   Value shift = createI64Constant(rewriter, loc, 32);
+   Value highBits =
+       LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
+   highBits =
+       LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), highBits);
+   constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
+   Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
+   highBits = LLVM::AndOp::create(rewriter, loc, highBits, mask);
+   sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, highBits, offset + 32);
+   return {sgpr2, sgpr3};
+ }
+
+ /// Fills the field at bit offset 32 (`sgpr1`). It holds the LDS address
+ /// increment when the op has one, and global tensor dimension 3 otherwise.
+ Value setTensorDim3OrLDSAddrIncrement(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1,
+ ArrayRef<Value> consts) const {
+   constexpr int64_t offset = 32;
+   if (op.getLdsIncrement())
+     return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
+                                offset);
+
+   constexpr int64_t dim = 3;
+   return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
+                        offset);
+ }
+
+ /// Fills the field at bit offset 64, spread over `sgpr2` and `sgpr3`. It holds
+ /// the global address increment when the op has one, and the stride of
+ /// tensor dimension 2 otherwise.
+ std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
+ MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc, Value sgpr2,
+ Value sgpr3, ArrayRef<Value> consts) const {
+   constexpr int32_t offset = 64;
+   if (op.getGlobalIncrement())
+     return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
+                                   consts, offset);
+
+   constexpr int32_t dim = 2;
+   return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
+                              consts, dim, offset);
+ }
+
+ /// Writes the iteration count into `sgpr3` at `offset`. The count is
+ /// truncated to i32 and `consts[1]` is subtracted from it before insertion.
+ /// `consts[1]` is presumably the constant 1, which would give a biased
+ /// "count - 1" encoding — confirm.
+ Value setIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr3, ArrayRef<Value> consts,
+ int32_t offset) const {
+   // pre-condition: iterationCount is in the inclusive interval [1, 256].
+   // TODO: validation if the value breaks the pre-condition.
+   // If the pre-condition fails, there is a possibility of
+   // affecting the higher bits. In a following PR add a flag that instruments
+   // conditions that need to be checked at runtime.
+   Value count = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
+                                       adaptor.getIterationCount());
+   Value encodedCount = LLVM::SubOp::create(rewriter, loc, count, consts[1]);
+   return setValueAtOffset(rewriter, loc, sgpr3, encodedCount, offset);
+ }
+
+ /// Fills the field at bit offset 112 (`sgpr3`). It holds tile dimension 3
+ /// when the op has no iteration count, and the iteration count otherwise.
+ Value setTileDim3OrIterateCount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr3,
+ ArrayRef<Value> consts) const {
+   Value iterateCount = op.getIterationCount();
+   // The dimension index must name the field being written (tile dim 3),
+   // matching the convention used by setTileDim4 (dim = 4) and
+   // setTensorDim3Stride (dim = 3).
+   constexpr int32_t dim = 3;
+   constexpr int32_t offset = 112;
+   if (!iterateCount)
+     return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
+                        offset);
+
+   return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
+ }
+
+ /// Builds descriptor group 2 as the type-converted form of vector<4xi32>.
+ /// The four words are filled, in order, with:
+ ///   [0] tensor dim 2,
+ ///   [1] tensor dim 3 or LDS address increment,
+ ///   [2..3] tensor dim 2 stride or global address increment,
+ ///   [3] (upper part) tile dim 3 or iteration count.
+ /// Returns poison when the op has no LDS increment and is rank 2.
+ Value getDGroup2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> consts) const {
+ IntegerType i32 = rewriter.getI32Type();
+ Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed.");
+
+ // NOTE(review): this condition does not consult getGlobalIncrement() or
+ // getIterationCount(). Confirm they cannot be set without an LDS increment.
+ // The identical check in getDGroup3 is being revisited in review
+ // (presence vs. value of the LDS increment).
+ if (!op.getLdsIncrement() && op.getRank() == 2)
+ return LLVM::PoisonOp::create(rewriter, loc, v4i32);
+
+ // Every word starts from consts[0] (presumably an i32 zero — confirm) and
+ // is then updated by the field setters below.
+ constexpr int64_t sgprlen = 4;
+ Value sgprs[sgprlen];
+ for (int i = 0; i < sgprlen; i++)
+ sgprs[i] = consts[0];
+
+ sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
+ sgprs[1], consts);
+ std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
+ op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
+ // Must follow the stride/global-increment update, which also writes sgprs[3].
+ sgprs[3] =
+ setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
+
+ // consts[i] doubles as the insertion index for word i (presumably
+ // consts[i] == i for i < 4 — confirm).
+ Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+ for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
+ dgroup2 =
+ LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
+
+ return dgroup2;
+ }
+
+ /// Writes the stride of tensor dimension 3 at bit offset 0. It takes the
+ /// accumulator pair (`sgpr0`, `sgpr1`) and returns the updated pair.
+ std::pair<Value, Value>
+ setTensorDim3Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
+   return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1,
+                              consts, /*dim=*/3, /*offset=*/0);
+ }
+
+ /// Writes global tensor dimension 4 at bit offset 48. A 32-bit value at that
+ /// offset straddles two words, so it takes and returns the pair
+ /// (`sgpr1`, `sgpr2`).
+ /// NOTE(review): this calls a two-accumulator setTensorDimX overload.
+ /// The setTensorDimX defined in this hunk takes a single accumulator and
+ /// returns a single Value. Confirm the overload exists elsewhere; if it does
+ /// not, this will not compile.
+ std::pair<Value, Value> setTensorDim4(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1, Value sgpr2,
+ ArrayRef<Value> consts) const {
+ constexpr int32_t dim = 4;
+ constexpr int32_t offset = 48;
+ return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
+ offset);
+ }
+
+ /// Writes tile dimension 4 into `sgpr2` at bit offset 80.
+ Value setTileDim4(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr2, ArrayRef<Value> consts) const {
+   return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, /*dim=*/4,
+                      /*offset=*/80);
+ }
+
+ Value getDGroup3(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> consts) const {
+ IntegerType i32 = rewriter.getI32Type();
+ Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed.");
+ if (!op.getLdsIncrement() && op.getRank() == 2)
----------------
amd-eochoalo wrote:
Thanks for pointing this out. I think I need to change the condition: `getLdsIncrement()` returns the value, which may be zero, whereas what I actually need to check is whether the attribute is present.
https://github.com/llvm/llvm-project/pull/171498
More information about the Mlir-commits
mailing list