[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)
Ivan Butygin
llvmlistbot at llvm.org
Thu Apr 10 04:03:44 PDT 2025
================
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
}
};
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+ if (!loadOp)
+ return rewriter.notifyMatchFailure(op, "not a load op");
+
+ if (!loadOp->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ VectorType memVecType = loadOp.getVectorType();
+ if (memVecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ MemRefType memType = loadOp.getMemRefType();
+ if (isa<VectorType>(memType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ int64_t rankOffset = memType.getRank() - memVecType.getRank();
+ if (rankOffset < 0)
+ return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+ auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+ int64_t finalRank = 0;
+ if (resVecType)
+ finalRank = resVecType.getRank();
+
+ SmallVector<Value> indices = loadOp.getIndices();
+ SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loadOp);
+ Location loc = loadOp.getLoc();
+ for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+ OpFoldResult pos = extractPos[i - rankOffset];
+ if (isConstantIntValue(pos, 0))
+ continue;
+
+ Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+ auto ovf = arith::IntegerOverflowFlags::nsw;
----------------
Hardcode84 wrote:
Not sure what the question is again. nsw/nuw are hints to the compiler about some external behavior that allow more aggressive optimizations, and setting them on by default in the builder is a bad idea (and IIRC LLVM proper doesn't set them by default either). Specifically for index/offset calculations, we never expect or rely on them to overflow in a signed way, so we can set `nsw` everywhere. However, we do allow intermediate calculations to be negative (and we also allow negative strides on memrefs), so we can't set `nuw`. For this specific code we would probably never encounter negative values either, but I would still keep only `nsw` just to be safe.
https://github.com/llvm/llvm-project/pull/134389
More information about the Mlir-commits
mailing list