[Mlir-commits] [mlir] [mlir][vector] Avoid use of vector.splat in transforms (PR #150279)

Diego Caballero llvmlistbot at llvm.org
Tue Jul 29 09:20:46 PDT 2025


================
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
     Type eltType = dstType.getElementType();
 
-    // Scalar to any vector can use splat.
-    if (!srcType) {
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
-      return success();
-    }
+    // A broadcast from a scalar is considered to be in the lowered form.
+    if (!srcType)
+      return failure();
 
     // Determine rank of source and destination.
     int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+      SmallVector<int64_t> fullRankPosition(srcRank, 0);
+      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
----------------
dcaballe wrote:

Could we use `rewriter.create<...>` here?

https://github.com/llvm/llvm-project/pull/150279

