[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Krzysztof Drewniak via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 08:47:56 PDT 2025


================
@@ -117,20 +145,94 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), indices);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+    // Note below doesn't give the correct result for the linearized size.
+    // Value totalSize = getValueOrCreateConstantIndexOp(
+    //    rewriter, loc, linearizedInfo.linearizedSize);
+    // It compute the mutiplied sizes of all dimensions instead of taking
+    // the maximum of each dimension size * stride.
+    SmallVector<AffineExpr> productExpressions;
+    SmallVector<Value> productResults;
+    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
+
+    SmallVector<AffineExpr> symbols(2 * sourceRank);
+    SmallVector<Value> offsetValues(2 * sourceRank);
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+    for (size_t i = 0; i < sourceRank; ++i) {
+      unsigned offsetIdx = 2 * i;
+      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
+      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
+      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
     }
 
-    rewriter.replaceOp(readOp, res);
+    AffineMap maxMap = AffineMap::get(
+        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        rewriter.getContext());
+    Value totalSize =
+        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+
+    // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
----------------
krzysz00 wrote:

I think there's a phase-ordering issue between this pattern and narrow type emulation.

If narrow type emulation runs first, fp4 loads become byte loads.

Otherwise, 4-bit types need a multiply by 2, and for 6-bit types, implementing a masked load seems extremely tricky, so we should probably just bail out of this pattern.

Maybe the right move is to not run any of this logic on sub-byte types, and to leave a note to revisit this later.

https://github.com/llvm/llvm-project/pull/135014


More information about the llvm-commits mailing list