[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Zhuoran Yin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 14:11:10 PDT 2025


================
@@ -117,20 +145,94 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), indices);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function.
+    // Note: the snippet below doesn't give the correct linearized size.
+    // Value totalSize = getValueOrCreateConstantIndexOp(
+    //    rewriter, loc, linearizedInfo.linearizedSize);
+    // It computes the multiplied sizes of all dimensions instead of taking
+    // the maximum of each dimension size * stride.
----------------
jerryyin wrote:

Done in ddee079

https://github.com/llvm/llvm-project/pull/135014


More information about the llvm-commits mailing list