[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Tue Sep 30 04:21:30 PDT 2025
================
@@ -2410,6 +2441,88 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+namespace {
+
+struct ToElementsOfVectorBroadcast final
+ : public OpRewritePattern<ToElementsOp> {
+ using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+ PatternRewriter &rewriter) const override {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ // Only handle broadcasts from a vector source here.
+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+ if (!srcType)
+ return failure();
+
+ auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ // Bail on scalable vectors.
+ if (srcType.getNumScalableDims() != 0 || dstType.getNumScalableDims() != 0)
+ return failure();
+
+ ArrayRef<int64_t> dstShape = dstType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+
+ unsigned dstRank = dstShape.size();
+ unsigned srcRank = srcShape.size();
+ if (srcRank > dstRank)
+ return failure();
+
+ // Verify broadcastability (right-aligned)
----------------
keshavvinayak01 wrote:
This check is pointless, to be honest; I don't know why I added it earlier. Thanks for being thorough.
https://github.com/llvm/llvm-project/pull/160318
More information about the Mlir-commits mailing list