[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)

Kunwar Grover llvmlistbot at llvm.org
Tue Sep 30 02:08:33 PDT 2025


================
@@ -2410,6 +2441,88 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+namespace {
+
+/// Rewrites a `vector.to_elements` of a `vector.broadcast` whose source is a
+/// vector into a `vector.to_elements` of the (smaller) broadcast source. Each
+/// destination element is replaced by the source element it was broadcast
+/// from, using right-aligned broadcasting semantics.
+///
+/// Example:
+///   %b = vector.broadcast %v : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %b : vector<3x2xf32>
+/// is rewritten so that %e is replaced by
+///   (%s#0, %s#1, %s#0, %s#1, %s#0, %s#1)
+/// where
+///   %s:2 = vector.to_elements %v : vector<2xf32>
+///
+/// Broadcasts from a scalar source and scalable vectors are not handled here.
+struct ToElementsOfVectorBroadcast final
+    : public OpRewritePattern<ToElementsOp> {
+  using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    // Bail on scalable vectors: the index mapping below assumes fixed shapes.
+    if (srcType.isScalable() || dstType.isScalable())
+      return failure();
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    const int64_t dstRank = dstType.getRank();
+    const int64_t srcRank = srcType.getRank();
+    if (srcRank > dstRank)
+      return failure();
+
+    // Number of leading destination dims with no counterpart in the source.
+    const int64_t rankDiff = dstRank - srcRank;
+
+    // Verify broadcastability (right-aligned): every source dim must be 1 or
+    // equal to the matching trailing destination dim. Leading destination
+    // dims are implicitly broadcast and need no check.
+    for (int64_t k = 0; k < srcRank; ++k) {
+      if (srcShape[k] != 1 && srcShape[k] != dstShape[rankDiff + k])
+        return failure();
+    }
+
+    // Extract the elements of the broadcast source vector once.
+    Location loc = toElementsOp.getLoc();
+    auto srcElems = ToElementsOp::create(rewriter, loc, bcastOp.getSource());
+
+    const int64_t dstCount = dstType.getNumElements();
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+
+    // Walk destination positions in row-major order. `dstIdx` holds the
+    // current multi-dimensional destination index; it is mapped to the
+    // row-major linear index of the source element it was broadcast from.
+    SmallVector<int64_t> dstIdx(dstRank, 0);
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      int64_t srcLin = 0;
+      for (int64_t k = 0; k < srcRank; ++k) {
+        // Unit source dims are broadcast, so they always read index 0.
+        int64_t idx = (srcShape[k] == 1) ? 0 : dstIdx[rankDiff + k];
+        srcLin = srcLin * srcShape[k] + idx;
+      }
+      replacements.push_back(srcElems.getResult(srcLin));
+
+      // Advance `dstIdx` to the next row-major position (odometer-style
+      // increment); for rank-0 vectors this loop does nothing.
+      for (int64_t d = dstRank - 1; d >= 0; --d) {
+        if (++dstIdx[d] < dstShape[d])
+          break;
+        dstIdx[d] = 0;
+      }
+    }
+
+    rewriter.replaceOp(toElementsOp, replacements);
+    return success();
+  }
+};
+
+} // namespace
----------------
Groverkss wrote:

Is this how other files do it? I don't think so; usually it's just `// namespace`. Can you check whether your IDE is configured the way LLVM expects?

https://github.com/llvm/llvm-project/pull/160318


More information about the Mlir-commits mailing list