[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)

Chao Chen llvmlistbot at llvm.org
Thu Jun 5 08:22:48 PDT 2025


================
@@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType newType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
----------------
chencha3 wrote:

Good question. It doesn't seem applicable to me, but I'm not entirely sure. Please correct me if I'm wrong.

https://github.com/llvm/llvm-project/pull/142011


More information about the Mlir-commits mailing list