[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

Min-Yih Hsu llvmlistbot at llvm.org
Fri Jul 25 09:07:38 PDT 2025


================
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
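+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is broadcastable
+// to the broadcast's result type.
+// NOTE(review): broadcastability of x alone does not make this fold sound.
+// E.g. x : vector<2x1xf32> shape_cast to vector<1x2xf32>, then broadcast to
+// vector<2x2xf32>, gives result[i][j] = x[j][0], whereas broadcast(x) gives
+// result[i][j] = x[i][0]. The shape_cast likely needs restricting (e.g. to
+// only adding/removing leading unit dims) — confirm.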
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto srcShapeCast =
+            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+      VectorType srcType = srcShapeCast.getSourceVectorType();
+      VectorType destType = broadcastOp.getResultVectorType();
+      if (vector::isBroadcastableTo(srcType, destType) ==
+          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                                 srcShapeCast.getSource());
+        return success();
+      }
+    }
+    return failure();
----------------
mshockwave wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/150523


More information about the Mlir-commits mailing list