[Mlir-commits] [mlir] [mlir][vector] Sink vector.extract/splat into load/store ops (PR #134389)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Apr 22 05:35:29 PDT 2025


================
@@ -1043,6 +1047,144 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Returns true if `type` is an element type the load/store sinking patterns
+/// accept: either `index`, or an integer/float type whose bit width is a
+/// multiple of 8.
+static bool isSupportedMemSinkElementType(Type type) {
+  // `index` is accepted first; the bit-width query is only reached for
+  // integer/float types. Non-byte-aligned types are tricky, skip them.
+  return isa<IndexType>(type) ||
+         (type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0);
+}
+
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+///
+/// The extract positions are added onto the corresponding load indices, and
+/// the original (single-use) load is replaced by a narrower load that reads
+/// only the extracted scalar or sub-vector.
+///
+/// Example:
+/// ```
+/// %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    // The extracted vector must be produced directly by a vector.load.
+    auto srcLoad = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!srcLoad)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Require a single use so the rewrite never duplicates the load.
+    if (!srcLoad->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType srcVecType = srcLoad.getVectorType();
+    if (srcVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memrefType = srcLoad.getMemRefType();
+    if (!isSupportedMemSinkElementType(memrefType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported memref element type");
+
+    // Number of leading memref dimensions not covered by the loaded vector.
+    int64_t leadingDims = memrefType.getRank() - srcVecType.getRank();
+    if (leadingDims < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    // A scalar result is treated as rank 0; otherwise use the rank of the
+    // extracted sub-vector.
+    auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t resultRank = resultVecType ? resultVecType.getRank() : 0;
+
+    SmallVector<Value> newIndices = srcLoad.getIndices();
+    SmallVector<OpFoldResult> positions = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so the
+    // replacement load has to be created where the original load is.
+    // NOTE(review): a dynamic extract position is used as-is at this insertion
+    // point; confirm such a value always dominates the original load.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(srcLoad);
+    Location loc = srcLoad.getLoc();
+    ArithIndexingBuilder indexBuilder(rewriter, loc);
+
+    // Only the dimensions consumed by the extract get an offset; the trailing
+    // `resultRank` indices are left untouched.
+    const int64_t firstDim = leadingDims;
+    const int64_t endDim = newIndices.size() - resultRank;
+    for (int64_t dim : llvm::seq<int64_t>(firstDim, endDim)) {
+      OpFoldResult pos = positions[dim - leadingDims];
+      // A constant-zero position needs no arithmetic.
+      if (!isConstantIntValue(pos, 0)) {
+        Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+        newIndices[dim] = indexBuilder.add(newIndices[dim], offset);
+      }
+    }
+
+    Value base = srcLoad.getBase();
+    if (!resultVecType) {
+      // Scalar result: a plain memref.load is enough.
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, newIndices);
+    } else {
+      // Sub-vector result: load just that narrower vector.
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resultVecType, base,
+                                                  newIndices);
+    }
+
+    // The single-use check above means the original load is now dead.
+    rewriter.eraseOp(srcLoad);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element, vectors are supported");
----------------
banach-space wrote:

[nit]
```suggestion
          op, "only 1-element vectors are supported");
```

https://github.com/llvm/llvm-project/pull/134389


More information about the Mlir-commits mailing list