[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)
Kunwar Grover
llvmlistbot at llvm.org
Tue Sep 30 05:08:40 PDT 2025
================
@@ -2410,6 +2443,91 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+/// Rewrites `vector.to_elements(vector.broadcast(%v))`, where `%v` is a
+/// fixed-size vector, into `vector.to_elements %v` and forwards every result
+/// of the original `to_elements` to the source element it was broadcast from.
+///
+/// Destination position (d_0, ..., d_{R-1}) reads source position
+/// (s_0, ..., s_{r-1}) with s_k = 0 when srcShape[k] == 1 (a stretched unit
+/// dimension) and s_k = d_{R-r+k} otherwise; the R-r leading destination
+/// dimensions have no source counterpart. Simply tiling the source elements
+/// is therefore only correct when no unit dimension is stretched, e.g.
+/// vector<2x1xf32> -> vector<2x3xf32> must produce a,a,a,b,b,b rather than
+/// a,b,a,b,a,b.
+///
+/// Example:
+///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
+///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+///   //       %src_elems#1, %src_elems#0, %src_elems#1
+class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    // Bail on scalable vectors: the mapping below needs static shapes.
+    if (srcType.getNumScalableDims() != 0 || dstType.getNumScalableDims() != 0)
+      return failure();
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    int64_t dstRank = dstShape.size();
+    int64_t srcRank = srcShape.size();
+    if (srcRank > dstRank)
+      return failure();
+
+    // Row-major source strides, with stride 0 on unit dimensions so that a
+    // stretched dimension always reads source index 0.
+    SmallVector<int64_t> srcStrides(srcRank, 0);
+    int64_t stride = 1;
+    for (int64_t k = srcRank - 1; k >= 0; --k) {
+      if (srcShape[k] != 1)
+        srcStrides[k] = stride;
+      stride *= srcShape[k];
+    }
+
+    // Create elements for the broadcast source vector.
+    auto srcElems = rewriter.create<ToElementsOp>(toElementsOp.getLoc(),
+                                                  bcastOp.getSource());
+
+    // Number of destination elements, computed as int64_t by the type itself
+    // (an int-seeded accumulation would truncate large products).
+    int64_t dstCount = dstType.getNumElements();
+
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      // Decode only the trailing `srcRank` destination dimensions, innermost
+      // first; the leading dimensions merely repeat the same pattern.
+      int64_t rem = lin;
+      int64_t srcLin = 0;
+      for (int64_t k = srcRank - 1; k >= 0; --k) {
+        int64_t dim = dstShape[dstRank - srcRank + k];
+        srcLin += (rem % dim) * srcStrides[k];
+        rem /= dim;
+      }
+      replacements.push_back(srcElems.getResult(srcLin));
+    }
+
+    rewriter.replaceOp(toElementsOp, replacements);
+    return success();
+  }
+};
----------------
Groverkss wrote:
Do you really need this decoding? You can just do ToElementsOp on the source vector, and just replicate it n times, where n is the broadcasted number of elements.
https://github.com/llvm/llvm-project/pull/160318
More information about the Mlir-commits mailing list