[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Zhuoran Yin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 07:31:48 PDT 2025


================
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index by linearIndex += indices[i] * stride
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+      if (i == 0)
+        break;
+
+      // Update stride for the next dimension
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add vector size offset to linear index
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    rewriter.replaceOp(readOp, res);
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
----------------
jerryyin wrote:

I cannot remove the attribute-based opt-out strategy, because without it the pass would recursively create an infinite number of transfer_read ops on the else branch. The attribute marks each transfer_read that has already been populated so that it is not lowered again.

Please review the 3rd commit: a09cd51 for:
- Changing the attribute name to `amdgpu.buffer_transfer_read_needs_mask`
- Introducing `populateVectorTransferLoweringPatterns` so that this attribute gets destroyed immediately by the next rewrite pattern (transfer_read to masked_load)

https://github.com/llvm/llvm-project/pull/135014


More information about the llvm-commits mailing list