[Mlir-commits] [llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Apr 9 13:03:44 PDT 2025


================
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index by linearIndex += indices[i] * stride
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+      if (i == 0)
+        break;
+
+      // Update stride for the next dimension
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add vector size offset to linear index
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    rewriter.replaceOp(readOp, res);
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
----------------
krzysz00 wrote:

I'd rather not use this strategy of disabling the pattern by tagging the cloned op with an attribute.

Can we just create the transfer_read in the else branch and invoke the masked-load lowering on it right then and there?



https://github.com/llvm/llvm-project/pull/135014


More information about the Mlir-commits mailing list