[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Apr 8 10:37:42 PDT 2025


================
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+///  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
----------------
banach-space wrote:

Is this required? Why not re-use `memVecType`? `getVectorType()` returns the result type: https://github.com/llvm/llvm-project/blob/0fc7aec349394d4713bd88fb5f0319e39b96f187/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1748-L1750

https://github.com/llvm/llvm-project/pull/134389


More information about the Mlir-commits mailing list