[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)
Zhuoran Yin via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 14 07:31:48 PDT 2025
================
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
Location loc = readOp.getLoc();
- Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = rewriter.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
+ Value src = readOp.getSource();
+ MemRefType memRefType = cast<MemRefType>(src.getType());
+ ArrayRef<int64_t> shape = memRefType.getShape();
+
+ Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stride = one;
+
+ // Compute the linear index by linearIndex += indices[i] * stride
+ for (int i = shape.size() - 1; i >= 0; --i) {
+ Value currentIndex = readOp.getIndices()[i];
+ Value strideIndexed =
+ rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+ linearIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+ if (i == 0)
+ break;
+
+ // Update stride for the next dimension
+ Value nextStride;
+ if (shape[i] != ShapedType::kDynamic) {
+ nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+ }
+
+ // Add vector size offset to linear index
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value upperBoundIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+ Value totalSize = one;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ Value dimensionSize;
+ if (shape[i] != ShapedType::kDynamic) {
+ dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- rewriter.replaceOp(readOp, res);
+ Value isInBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+ readOp.getSource(),
+ readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+ res);
+ }
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Operation *read = builder.clone(*readOp.getOperation());
+ read->setAttr("amdgpu.transformed", builder.getUnitAttr());
----------------
jerryyin wrote:
I cannot delete the attribute-based turn-off strategy: without it, the pass would recursively create an infinite number of transfer_read ops on the else branch. The attribute marks a transfer_read that has already been transformed, so the pattern does not lower it again.
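For reference, a minimal sketch of the guard described above (names mirror the diff; this is illustrative, not the PR verbatim):

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct GuardedTransferReadLowering final
    : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // Skip ops this pattern already produced: the else branch clones the
    // original transfer_read and tags it, so without this check the clone
    // would be matched again and the rewrite would recurse forever.
    if (readOp->hasAttr("amdgpu.transformed"))
      return rewriter.notifyMatchFailure(readOp, "already transformed");

    // ... build the bounds check and scf.if as in the diff; in the else
    // builder the clone is tagged:
    //   Operation *read = builder.clone(*readOp.getOperation());
    //   read->setAttr("amdgpu.transformed", builder.getUnitAttr());
    return failure(); // body elided in this sketch
  }
};
```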
Please review the 3rd commit (a09cd51) for:
- Changing the attribute name to `amdgpu.buffer_transfer_read_needs_mask`
- Introducing `populateVectorTransferLoweringPatterns` so the attribute is immediately consumed by the next rewrite pattern (transfer_read to masked_load); see the sketch below
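Roughly, that pairing could be wired up like this when populating the pass (a sketch; `populatePatterns` is a hypothetical helper, and the exact registration in the PR may differ):

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// TransferReadLowering is the pattern from the diff above.
void populatePatterns(MLIRContext *ctx, RewritePatternSet &patterns) {
  // This PR's pattern: wraps the load in a bounds-checked scf.if and tags
  // the else-branch clone with amdgpu.buffer_transfer_read_needs_mask.
  patterns.add<TransferReadLowering>(ctx);
  // Upstream vector lowering: rewrites the tagged transfer_read into a
  // masked load, so the tagged op (and its attribute) disappears within
  // the same greedy-rewrite run.
  vector::populateVectorTransferLoweringPatterns(patterns);
}
```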
https://github.com/llvm/llvm-project/pull/135014