[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

Min-Yih Hsu llvmlistbot at llvm.org
Fri Jul 25 09:08:15 PDT 2025


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/150523

>From 9ca07a1022b7421e740390dff3e5aa2046a24e61 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 13:55:56 -0700
Subject: [PATCH 1/3] [mlir][vector] Canonicalize broadcast of shape_cast

Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 24 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto srcShapeCast =
+            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+      VectorType srcType = srcShapeCast.getSourceVectorType();
+      VectorType destType = broadcastOp.getResultVectorType();
+      if (vector::isBroadcastableTo(srcType, destType) ==
+          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                                 srcShapeCast.getSource());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+//       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

>From 10a914efacadd06d8dc40c266c1a85416d546782 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 09:06:35 -0700
Subject: [PATCH 2/3] fixup! Address review comments

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad908319d8584..348c713980ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
 
   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                 PatternRewriter &rewriter) const override {
-    if (auto srcShapeCast =
-            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
-      VectorType srcType = srcShapeCast.getSourceVectorType();
-      VectorType destType = broadcastOp.getResultVectorType();
-      if (vector::isBroadcastableTo(srcType, destType) ==
-          BroadcastableToResult::Success) {
-        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
-                                                 srcShapeCast.getSource());
-        return success();
-      }
-    }
-    return failure();
+    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!srcShapeCast)
+      return failure();
+
+    VectorType srcType = srcShapeCast.getSourceVectorType();
+    VectorType destType = broadcastOp.getResultVectorType();
+    if (vector::isBroadcastableTo(srcType, destType) !=
+        BroadcastableToResult::Success)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                             srcShapeCast.getSource());
+    return success();
   }
 };
 } // namespace

>From 067f1150c3b6ea87cd9b09f64949b92d22087c28 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min at myhsu.dev>
Date: Fri, 25 Jul 2025 09:08:06 -0700
Subject: [PATCH 3/3] fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0fd2acd06c8ec..fc4ef6bf39379 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1182,7 +1182,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
 // CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
 //       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
 //       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
-func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
   %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
   %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
   return %1 : vector<2x4x16xf32>



More information about the Mlir-commits mailing list