[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

Kunwar Grover llvmlistbot at llvm.org
Fri Aug 1 02:17:46 PDT 2025


================
@@ -2938,13 +2938,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+// Fold broadcast(shape_cast(x)) into broadcast(x).
+//
+// This is only valid when the shape_cast merely adds or removes leading unit
+// dimensions, i.e. x's shape and the shape_cast result shape agree on all of
+// their common trailing dimensions. x being broadcastable to the result type
+// is NOT sufficient on its own. For example, with x : vector<1x4xf32>:
+//   broadcast(shape_cast(x) : vector<4x1xf32>) : vector<4x4xf32>
+// replicates x along the columns, whereas
+//   broadcast(x) : vector<4x4xf32>
+// replicates x along the rows.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!srcShapeCast)
+      return failure();
+
+    VectorType srcType = srcShapeCast.getSourceVectorType();
+    VectorType castType = srcShapeCast.getResultVectorType();
+    VectorType destType = broadcastOp.getResultVectorType();
+
+    // x itself must be a legal broadcast source for the result type.
+    if (vector::isBroadcastableTo(srcType, destType) !=
+        BroadcastableToResult::Success)
+      return failure();
+
+    // Broadcast aligns dimensions from the back, so the shape_cast must leave
+    // the common trailing dimensions (sizes and scalability) untouched.
+    // Otherwise it has permuted/regrouped elements and dropping it changes
+    // which element lands at each position of the result. For a well-formed
+    // (element-count-preserving) shape_cast, whatever leading dimensions
+    // remain on either side are then unit dimensions, which broadcast treats
+    // the same whether present or absent.
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    ArrayRef<int64_t> castShape = castType.getShape();
+    size_t numTrailingDims = std::min(srcShape.size(), castShape.size());
+    if (srcShape.take_back(numTrailingDims) !=
+            castShape.take_back(numTrailingDims) ||
+        srcType.getScalableDims().take_back(numTrailingDims) !=
+            castType.getScalableDims().take_back(numTrailingDims))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                             srcShapeCast.getSource());
+    return success();
+  }
+};
----------------
Groverkss wrote:

This should be a folder, not a rewrite pattern.

https://github.com/llvm/llvm-project/pull/150523


More information about the Mlir-commits mailing list