[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)

Chao Chen llvmlistbot at llvm.org
Thu Jun 5 08:21:31 PDT 2025


================
@@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType newType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      if (!srcType) {
+        // Scalar to vector broadcast.
+        newSrc = broadcastOp.getSource();
+      } else {
+        // Vector to vector broadcast.
+        int64_t rank = srcType.getRank();
+        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - rank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+        // Adjust the offset and shape for src if the corresponding dim is 1.
----------------
chencha3 wrote:

Thanks, fixed. 

https://github.com/llvm/llvm-project/pull/142011


More information about the Mlir-commits mailing list