[Mlir-commits] [mlir] [mlir][vector] Avoid use of vector.splat in transforms (PR #150279)
Diego Caballero
llvmlistbot at llvm.org
Tue Jul 29 09:20:46 PDT 2025
================
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return failure();
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
----------------
dcaballe wrote:
Could we use `rewriter.create<...>` here?
https://github.com/llvm/llvm-project/pull/150279