[Mlir-commits] [mlir] [mlir][Vector] Pattern to linearize broadcast (PR #163845)
James Newling
llvmlistbot at llvm.org
Fri Oct 17 16:20:22 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/163845
>From b1477eebb87b8fc423a9829c10722c955d8e54a6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 16 Oct 2025 11:50:26 -0700
Subject: [PATCH 1/2] add pattern and tests
---
.../Vector/Transforms/VectorLinearize.cpp | 48 ++++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 41 ++++++++++++++++
2 files changed, 87 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1b656d82f3201..9b2f88b7bbe9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
}
};
+/// Convert broadcasts from scalars or 1-element vectors, such as
+///
+/// ```mlir
+/// vector.broadcast %value : f32 to vector<4x4xf32>
+/// ```
+///
+/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
+/// The above becomes,
+///
+/// ```mlir
+/// %out_1d = vector.splat %value : f32 to vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// ```
+struct LinearizeVectorBroadcast final
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using Base::Base;
+ /// Forwards the type converter, context and benefit to OpConversionPattern.
+ LinearizeVectorBroadcast(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ /// Replaces a single-element-source broadcast with one whose result type is the type converter's conversion of the original result type.
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // NOTE(review): confirm a scalable source such as vector<[1]xi32> is meant to pass the single-element check below.
+ int numElements = 1; // a non-vector (scalar) source counts as one element
+ Type sourceType = broadcastOp.getSourceType();
+ if (auto vecType = dyn_cast<VectorType>(sourceType)) {
+ numElements = vecType.getNumElements(); // NOTE(review): presumably a 64-bit count narrowed to int -- confirm
+ }
+ // Sources holding more than one element are reported as a match failure.
+ if (numElements != 1) {
+ return rewriter.notifyMatchFailure(
+ broadcastOp, "only broadcasts of single elements can be linearized.");
+ }
+ // Re-create the broadcast from the adaptor's source operand, using the converted result type.
+ auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
+ adaptor.getSource());
+
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
- LinearizeVectorFromElements, LinearizeVectorToElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorBroadcast, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index ee5cfbcda5c19..cbbc833d7a51d 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
// -----
+// CHECK-LABEL: linearize_vector_broadcast_scalar_source
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
+ // Scalar source: the broadcast result is expected as vector<8xi32>, then shape_cast back to 4x2.
+ // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+ // CHECK: return %[[CAST]] : vector<4x2xi32>
+ %0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
+ return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
+// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
+ // Rank-2 single-element source: the operand is expected to be shape_cast to vector<1xi32> before the broadcast.
+ // CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
+ // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
+ // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+ // CHECK: return %[[CAST1]] : vector<4x2xi32>
+ %0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
+ return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_scalable_vector_broadcast
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
+ // Scalable trailing dimension: the 4x[2] result is expected to linearize to vector<[8]xi32>.
+ // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32>
+ // CHECK: return %[[CAST]] : vector<4x[2]xi32>
+ %0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32>
+ return %0 : vector<4x[2]xi32>
+
+}
+
+// -----
+
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
>From 1e44d1f6f5cadc21b905a3c93bb624daa494a2f0 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 17 Oct 2025 16:23:39 -0700
Subject: [PATCH 2/2] update comment
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 9b2f88b7bbe9d..ea93085849e0b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -827,7 +827,7 @@ struct LinearizeVectorToElements final
/// The above becomes,
///
/// ```mlir
-/// %out_1d = vector.splat %value : f32 to vector<16xf32>
+/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
/// ```
struct LinearizeVectorBroadcast final
More information about the Mlir-commits
mailing list