[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Apr 8 10:37:39 PDT 2025
================
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
}
};
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+///
+/// The extract position is folded into the load indices, so only the
+/// extracted element (or sub-vector) is read from memory:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+/// When the extract result is itself a vector, a narrower vector.load of that
+/// vector type is emitted instead of a memref.load.
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    // Require a single use: the original load is erased below, and keeping it
+    // alive for other users would duplicate the memory access.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    // The extract position addresses memref dims starting at `rankOffset`,
+    // i.e. the leading `rankOffset` memref indices are left untouched.
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+    // A poison position makes the extract result poison. Folding it into the
+    // load indices would instead create a real (possibly out-of-bounds)
+    // memory access, so such extracts are left alone.
+    if (llvm::any_of(extractPos, [](OpFoldResult pos) {
+          return isConstantIntValue(pos, vector::ExtractOp::kPoisonIndex);
+        }))
+      return rewriter.notifyMatchFailure(op, "poison extract position");
+
+    // Rank of the extracted value: 0 for a scalar result, otherwise the rank
+    // of the extracted sub-vector (those trailing indices stay unchanged).
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = resVecType ? resVecType.getRank() : 0;
+
+    SmallVector<Value> indices = loadOp.getIndices();
+
+    // Build the index arithmetic and the new load at the position of the
+    // original load (not the extract), so that memory writes located between
+    // the two ops are not reordered across the new load.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      // A zero offset leaves the index unchanged; skip emitting an add.
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+      // NOTE(review): `nsw` assumes base index + extract position does not
+      // overflow, which should hold whenever the original load is in bounds.
+      auto ovf = arith::IntegerOverflowFlags::nsw;
+      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+    }
+
+    Value base = loadOp.getBase();
+    if (resVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // The extract was the load's only user (checked above), so it is dead now.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getVectorType();
+ if (vecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ if (isa<VectorType>(op.getMemRefType().getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ if (vecType.getNumElements() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "only 1-element, vectors are supported");
+
+ Operation *splat = op.getValueToStore().getDefiningOp();
+ if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+ return rewriter.notifyMatchFailure(op, "not a splat");
----------------
banach-space wrote:
```suggestion
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
```
https://github.com/llvm/llvm-project/pull/134389
More information about the Mlir-commits
mailing list