[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jul 25 08:21:37 PDT 2025
================
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
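+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is broadcastable
+// to the broadcast's result type.
+// NOTE(review): broadcastability alone does not make this fold sound. E.g. for
+// x = [[a],[b]] : vector<2x1xf32>, shape_cast to vector<1x2xf32> then broadcast
+// to vector<2x2xf32> gives [[a,b],[a,b]], whereas broadcast(x) gives [[a,a],[b,b]].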
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ if (auto srcShapeCast =
+ broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
+ }
+ }
+ return failure();
----------------
banach-space wrote:
[nit] Prefer [early exits](https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code) — they help reduce indentation.
```suggestion
auto srcShapeCast =
broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
if (!srcShapeCast)
return failure();
VectorType srcType = srcShapeCast.getSourceVectorType();
VectorType destType = broadcastOp.getResultVectorType();
if (vector::isBroadcastableTo(srcType, destType) !=
BroadcastableToResult::Success)
return failure();
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
srcShapeCast.getSource());
return success();
```
https://github.com/llvm/llvm-project/pull/150523
More information about the Mlir-commits
mailing list