[Mlir-commits] [mlir] [mlir][vector] Towards deprecating vector.splat (PR #150279)

James Newling llvmlistbot at llvm.org
Wed Jul 23 10:36:28 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/150279

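This PR removes all creation of vector.splat ops in the Vector/Transforms directory, emitting vector.broadcast instead. Patterns that match existing vector.splat ops still recognize them, via a new helper, getBroadcastLikeSource.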

>From b16fc77828370498560b7e6a31c86f5fb4a28b7e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 23 Jul 2025 10:32:21 -0700
Subject: [PATCH 1/2] changes in vector transforms

---
 .../Transforms/LowerVectorBroadcast.cpp       | 23 ++++----
 .../Vector/Transforms/LowerVectorTransfer.cpp |  2 +-
 ...sertExtractStridedSliceRewritePatterns.cpp |  2 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 56 ++++++++++++-------
 .../vector-broadcast-lowering-transforms.mlir | 16 +++---
 ...ctor-outerproduct-lowering-transforms.mlir | 24 ++++----
 6 files changed, 71 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566869cfd..fee73bd7b1afa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast without a scalar operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
     Type eltType = dstType.getElementType();
 
-    // Scalar to any vector can use splat.
-    if (!srcType) {
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
-      return success();
-    }
+    // A broadcast from a scalar is considered to be in the lowered form.
+    if (!srcType)
+      return failure();
 
     // Determine rank of source and destination.
     int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
-    if (srcRank <= 1 && dstRank == 1) {
-      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+    if (srcType.getNumElements() == 1 && dstRank == 1) {
+      SmallVector<int64_t> fullRankPosition(srcRank, 0);
+      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+                                            fullRankPosition);
+      assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..7122c53e49780 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
             read, "vector type is not rank 1, can't create masked load, needs "
                   "VectorToSCF");
 
-      Value fill = vector::SplatOp::create(
+      Value fill = vector::BroadcastOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
       res = vector::MaskedLoadOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d72bfe77..cbb9d4bbf0b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
     // Extract/insert on a lower ranked extract strided slice op.
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, dstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {
       Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..abcc49144187e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -940,7 +940,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
 
     SmallVector<int64_t> sliceShape = {castDstLastDim};
     SmallVector<int64_t> strides = {1};
@@ -966,6 +966,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// Return the value that `value` was broadcast from when it is produced by a
+/// vector.broadcast (its source) or a vector.splat (its input); otherwise
+/// return a null Value. Block arguments have no defining op and yield null.
+static Value getBroadcastLikeSource(Value value) {
+  // getDefiningOp<OpTy>() is null-safe: it yields a null op both for block
+  // arguments and for values produced by any other kind of operation.
+  if (auto broadcastOp = value.getDefiningOp<vector::BroadcastOp>())
+    return broadcastOp.getSource();
+
+  // vector.splat is still matched so that input IR containing it is handled,
+  // even though these transforms no longer create it.
+  if (auto splatOp = value.getDefiningOp<vector::SplatOp>())
+    return splatOp.getInput();
+
+  return Value();
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -1007,26 +1024,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
 
     // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
-      return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+    if (!lhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "operand #0 not the result of a broadcast");
+    Type lhsBcastOrSplatType = lhsSource.getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+    if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+          if (auto source = getBroadcastLikeSource(val))
+            return source.getType() == lhsBcastOrSplatType;
           return false;
         })) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "not all operands are broadcasts from the sametype");
     }
 
     // Collect the source values before broadcasting
@@ -1240,15 +1254,17 @@ class StoreOpFromSplatOrBroadcast final
       return rewriter.notifyMatchFailure(
           op, "only 1-element vectors are supported");
 
-    Operation *splat = op.getValueToStore().getDefiningOp();
-    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
-      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+    Value toStore = op.getValueToStore();
+    Value source = getBroadcastLikeSource(toStore);
+    if (!source)
+      return rewriter.notifyMatchFailure(
+          op, "value to store is not from a broadcast");
 
     // Checking for single use so we can remove splat.
+    Operation *splat = toStore.getDefiningOp();
     if (!splat->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
-    Value source = splat->getOperand(0);
     Value base = op.getBase();
     ValueRange indices = op.getIndices();
 
@@ -1298,13 +1314,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // Add in an offset if requested.
   if (off) {
     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
     indices = arith::AddIOp::create(rewriter, loc, ov, indices);
   }
   // Construct the vector comparison.
   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
-      vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
   return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                                indices, bounds);
 }
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a520260f..d5e344393b217 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func @broadcast_vec1d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
 // CHECK:      return %[[T0]] : vector<2xf32>
 
 func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
 
 // CHECK-LABEL: func @broadcast_vec2d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
 // CHECK:      return %[[T0]] : vector<2x3xf32>
 
 func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @broadcast_vec3d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
 // CHECK:      return %[[T0]] : vector<2x3x4xf32>
 
 func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK-LABEL: func @broadcast_stretch
 // CHECK-SAME: %[[A:.*0]]: vector<1xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
 // CHECK:      return %[[T1]] : vector<4xf32>
 
 func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK:      %[[U0:.*]] = ub.poison : vector<4x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK:      %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK:      %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
 // CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      return %[[T15]] : vector<4x3xf32>
 
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955f77313..5a8125ed67173 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
 // CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
 // CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK:      %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
 // CHECK-LABEL: func @axpy_fp(
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
 // CHECK-LABEL: func @axpy_int(
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: return %[[T1]] : vector<16xi32>
 func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
 // CHECK: return %[[T2]] : vector<16xi32>

>From d6a87a551521b7477e7b4a9a16a54eba5e532fe5 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 23 Jul 2025 10:35:11 -0700
Subject: [PATCH 2/2] further simplification of PR

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fee73bd7b1afa..0a48eab3e6666 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -51,7 +51,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
-    if (srcType.getNumElements() == 1 && dstRank == 1) {
+    if (srcRank <= 1 && dstRank == 1) {
       SmallVector<int64_t> fullRankPosition(srcRank, 0);
       Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
                                             fullRankPosition);



More information about the Mlir-commits mailing list