[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)
Zhuoran Yin via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 14 08:31:28 PDT 2025
================
@@ -117,20 +145,94 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
Location loc = readOp.getLoc();
- Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = rewriter.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
+ Value src = readOp.getSource();
+
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+ SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, elementBitWidth, elementBitWidth,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ Value linearIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+ // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+ // Note below doesn't give the correct result for the linearized size.
+ // Value totalSize = getValueOrCreateConstantIndexOp(
+ // rewriter, loc, linearizedInfo.linearizedSize);
+ // It computes the product of the sizes of all dimensions instead of taking
+ // the maximum of each dimension size * stride.
+ SmallVector<AffineExpr> productExpressions;
+ SmallVector<Value> productResults;
+ unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
+
+ SmallVector<AffineExpr> symbols(2 * sourceRank);
+ SmallVector<Value> offsetValues(2 * sourceRank);
+ bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+ for (size_t i = 0; i < sourceRank; ++i) {
+ unsigned offsetIdx = 2 * i;
+ productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
+ offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
+ offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
}
- rewriter.replaceOp(readOp, res);
+ AffineMap maxMap = AffineMap::get(
+ /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+ rewriter.getContext());
+ Value totalSize =
+ rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
+ // delta = bufferSize - linearizedOffset
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+ // 1) check if delta <= vectorSize
+ Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+
+ // 2) check if (delta_bytes % (32 / elementBitWidth) != 0)
+ Value deltaBytes = rewriter.create<arith::MulIOp>(
+ loc, delta,
+ rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+ Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+ loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
----------------
jerryyin wrote:
Yep, this is `ceilDiv` to prevent zero from appearing. I'll amend the implementation.
https://github.com/llvm/llvm-project/pull/135014
More information about the llvm-commits
mailing list