[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Kunwar Grover
llvmlistbot at llvm.org
Fri Aug 1 02:17:46 PDT 2025
================
@@ -2938,13 +2938,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) when doing so preserves
+// semantics:
+//   1. x's type must be broadcastable to the broadcast's result type, and
+//   2. the shape_cast may only add or remove leading unit dimensions, i.e. the
+//      trailing dimensions shared by x and the shape_cast result must match.
+//
+// Condition 2 is required because vector.broadcast aligns shapes on their
+// trailing dimensions. For example, with x : vector<4xf32>:
+//   %0 = vector.shape_cast %x : vector<4xf32> to vector<4x1xf32>
+//   %1 = vector.broadcast %0 : vector<4x1xf32> to vector<4x4xf32>
+// gives %1[i][j] == x[i], whereas broadcasting x directly to vector<4x4xf32>
+// gives x[j]. Since vector<4xf32> is broadcastable to vector<4x4xf32>, the
+// type check alone would accept this incorrect rewrite.
+//
+// NOTE(review): per review feedback this should probably be expressed as part
+// of BroadcastOp::fold rather than as a standalone rewrite pattern.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!srcShapeCast)
+      return failure();
+
+    VectorType srcType = srcShapeCast.getSourceVectorType();
+    VectorType destType = broadcastOp.getResultVectorType();
+    // x itself must be a legal broadcast source for the result type.
+    if (vector::isBroadcastableTo(srcType, destType) !=
+        BroadcastableToResult::Success)
+      return failure();
+
+    // Broadcasting aligns trailing dimensions, so the shape_cast must leave
+    // the common trailing dimensions unchanged. Because shape_cast preserves
+    // the element count, any remaining leading dimensions are then all 1, and
+    // the shape_cast is transparent to the broadcast.
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    ArrayRef<int64_t> castShape =
+        srcShapeCast.getResultVectorType().getShape();
+    size_t numTrailingDims = std::min(srcShape.size(), castShape.size());
+    if (srcShape.take_back(numTrailingDims) !=
+        castShape.take_back(numTrailingDims))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                             srcShapeCast.getSource());
+    return success();
+  }
+};
----------------
Groverkss wrote:
This should be a folder, not a rewrite pattern.
https://github.com/llvm/llvm-project/pull/150523
More information about the Mlir-commits
mailing list