[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)
Chao Chen
llvmlistbot at llvm.org
Thu Jun 5 08:21:31 PDT 2025
================
@@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+ UnrollBroadcastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, broadcastOp);
+ if (!targetShape)
+ return failure();
+
+ Location loc = broadcastOp.getLoc();
+ VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ VectorType resType = broadcastOp.getResultVectorType();
+ VectorType newType =
+ resType.cloneWith(*targetShape, resType.getElementType());
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+
+ SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ Value newSrc;
+ if (!srcType) {
+ // Scalar to vector broadcast.
+ newSrc = broadcastOp.getSource();
+ } else {
+ // Vector to vector broadcast.
+ int64_t rank = srcType.getRank();
+ SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+ SmallVector<int64_t> srcShape(targetShape->end() - rank,
+ targetShape->end());
+ SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+ // addjust the offset and shape for src if the corresponding dim is 1.
----------------
chencha3 wrote:
Thanks, fixed.
https://github.com/llvm/llvm-project/pull/142011
More information about the Mlir-commits
mailing list