[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)

Jakub Kuderski llvmlistbot at llvm.org
Wed Oct 1 05:18:55 PDT 2025


================
@@ -2410,6 +2440,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector.
+/// - Build `vector.to_elements %v` and remap each destination element to the
+///   corresponding source element using broadcast rules (match or 1 →
+///   replicate).
+///
+/// Example:
+///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
+///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+///   //       %src_elems#1, %src_elems#0, %src_elems#1
+class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+
+    int64_t dstRank = dstShape.size();
+    int64_t srcRank = srcShape.size();
+
+    // Create elements for the broadcast source vector.
+    auto srcElems = vector::ToElementsOp::create(
+        rewriter, toElementsOp.getLoc(), bcastOp.getSource());
+
+    int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
+                                       std::multiplies<int64_t>());
+
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+
+    // For each element of the destination, determine which element of the
+    // source should be used. We walk all destination positions using a single
+    // counter, decode it into per-dimension indices, then build the matching
+    // source position: use the same index where sizes match, and use 0 where
+    // the source size is 1 (replication). This mapping is needed so we can
+    // replace each result of to_elements with the corresponding element from
+    // the broadcast source.
+    // Inner-dimension stretch example:
+    //   %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
+    //   %e:12 = vector.to_elements %v : vector<2x3x2xf32>
+    // becomes:
+    //   %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
+    //   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+    //   //       %src_elems#1, %src_elems#0, %src_elems#1,
+    //   //       %src_elems#2, %src_elems#3, %src_elems#2,
+    //   //       %src_elems#3, %src_elems#2, %src_elems#3
+
+    SmallVector<int64_t> dstIdx(dstShape.size());
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      int64_t temp = lin;
+      for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+        int64_t dim = dstShape[i];
+        dstIdx[i] = temp % dim;
+        temp /= dim;
+      }
+      int64_t srcLin = 0;
+      for (int64_t k = 0; k < srcRank; ++k)
+        srcLin = srcLin * srcShape[k] +
+                 ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
----------------
kuhar wrote:

Are there no existing helpers we could use for this? We might have similar code for insert/extract folders, but I don't know if we can reuse it here. 
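
For reference, the index mapping in the hunk above can be isolated as a standalone sketch, which may make it easier to compare against any existing utility. The names `mapToBroadcastSource` and `broadcastSourceIndices` are hypothetical, not existing MLIR APIs, and plain `std::vector` stands in for `ArrayRef`/`SmallVector`. It assumes static shapes and that the source rank does not exceed the destination rank.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an existing MLIR API): maps a row-major linear
// index into the broadcast result to the row-major linear index of the
// source element that the broadcast replicates there.
int64_t mapToBroadcastSource(int64_t dstLin,
                             const std::vector<int64_t> &dstShape,
                             const std::vector<int64_t> &srcShape) {
  const int64_t dstRank = static_cast<int64_t>(dstShape.size());
  const int64_t srcRank = static_cast<int64_t>(srcShape.size());
  assert(srcRank <= dstRank && "broadcast cannot reduce rank");

  // Decode the linear index into per-dimension indices, innermost first.
  std::vector<int64_t> dstIdx(dstRank);
  for (int64_t i = dstRank - 1; i >= 0; --i) {
    dstIdx[i] = dstLin % dstShape[i];
    dstLin /= dstShape[i];
  }

  // Re-encode against the source shape. Source dims align with the trailing
  // destination dims; a source dim of size 1 always reads index 0.
  int64_t srcLin = 0;
  for (int64_t k = 0; k < srcRank; ++k) {
    int64_t idx = srcShape[k] == 1 ? 0 : dstIdx[dstRank - srcRank + k];
    srcLin = srcLin * srcShape[k] + idx;
  }
  return srcLin;
}

// Hypothetical driver: source linear index for every destination element.
std::vector<int64_t>
broadcastSourceIndices(const std::vector<int64_t> &dstShape,
                       const std::vector<int64_t> &srcShape) {
  int64_t dstCount = 1;
  for (int64_t d : dstShape)
    dstCount *= d;

  std::vector<int64_t> result;
  result.reserve(dstCount);
  for (int64_t lin = 0; lin < dstCount; ++lin)
    result.push_back(mapToBroadcastSource(lin, dstShape, srcShape));
  return result;
}
```

Under these assumptions, `broadcastSourceIndices({3, 2}, {2})` gives 0, 1, 0, 1, 0, 1 and `broadcastSourceIndices({2, 3, 2}, {2, 1, 2})` gives 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, which agrees with the two examples in the patch comments.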

https://github.com/llvm/llvm-project/pull/160318

