[Mlir-commits] [mlir] [mlir][Vector] Pattern to linearize broadcast (PR #163845)

James Newling llvmlistbot at llvm.org
Thu Oct 16 11:49:32 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/163845

The PR https://github.com/llvm/llvm-project/pull/162167 removed a pattern to linearize vector.splat, without adding the equivalent pattern for vector.broadcast. This PR adds such a pattern, hopefully bringing vector.broadcast up to full parity with vector.splat, which has now been removed.


>From b1477eebb87b8fc423a9829c10722c955d8e54a6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 16 Oct 2025 11:50:26 -0700
Subject: [PATCH] add pattern and tests

---
 .../Vector/Transforms/VectorLinearize.cpp     | 48 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 41 ++++++++++++++++
 2 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1b656d82f3201..9b2f88b7bbe9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
   }
 };
 
+/// Convert broadcasts from scalars or 1-element vectors, such as
+///
+/// ```mlir
+///   vector.broadcast %value : f32 to vector<4x4xf32>
+/// ```
+///
+/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
+/// The above becomes,
+///
+/// ```mlir
+///   %out_1d = vector.broadcast %value : f32 to vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// ```
+struct LinearizeVectorBroadcast final
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using Base::Base;
+
+  LinearizeVectorBroadcast(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // A non-vector (scalar) source is treated as a single element.
+    int numElements = 1;
+    Type sourceType = broadcastOp.getSourceType();
+    if (auto vecType = dyn_cast<VectorType>(sourceType)) {
+      numElements = vecType.getNumElements();
+    }
+
+    if (numElements != 1) {
+      return rewriter.notifyMatchFailure(
+          broadcastOp, "only broadcasts of single elements can be linearized.");
+    }
+    // Rebuild the broadcast from the adapted source to the converted type.
+    auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
+                                                     adaptor.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
-           LinearizeVectorFromElements, LinearizeVectorToElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorBroadcast, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index ee5cfbcda5c19..cbbc833d7a51d 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: linearize_vector_broadcast_scalar_source
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
+  // Scalar source: broadcast to the rank-1 type, then shape_cast to 4x2.
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+  // CHECK: return %[[CAST]] : vector<4x2xi32>
+  %0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
+// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
+  // Single-element rank-2 source: it is shape_cast to rank-1 before the broadcast.
+  // CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
+  // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+  // CHECK: return %[[CAST1]] : vector<4x2xi32>
+  %0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_scalable_vector_broadcast
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
+  // Scalable result: 4x[2] is linearized to the scalable rank-1 type [8].
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32>
+  // CHECK: return %[[CAST]] : vector<4x[2]xi32>
+  %0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32>
+  return %0 : vector<4x[2]xi32>
+
+}
+
+// -----
+
 // CHECK-LABEL: linearize_create_mask
 // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
 func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {



More information about the Mlir-commits mailing list