[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
    Min-Yih Hsu 
    llvmlistbot at llvm.org
       
    Thu Aug  7 16:26:51 PDT 2025
    
    
  
https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/150523
>From 9ca07a1022b7421e740390dff3e5aa2046a24e61 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 13:55:56 -0700
Subject: [PATCH 1/7] [mlir][vector] Canonicalize broadcast of shape_cast
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 24 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto srcShapeCast =
+            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+      VectorType srcType = srcShapeCast.getSourceVectorType();
+      VectorType destType = broadcastOp.getResultVectorType();
+      if (vector::isBroadcastableTo(srcType, destType) ==
+          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                                 srcShapeCast.getSource());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+//       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 10a914efacadd06d8dc40c266c1a85416d546782 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 09:06:35 -0700
Subject: [PATCH 2/7] fixup! Address review comments
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad908319d8584..348c713980ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
 
   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                 PatternRewriter &rewriter) const override {
-    if (auto srcShapeCast =
-            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
-      VectorType srcType = srcShapeCast.getSourceVectorType();
-      VectorType destType = broadcastOp.getResultVectorType();
-      if (vector::isBroadcastableTo(srcType, destType) ==
-          BroadcastableToResult::Success) {
-        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
-                                                 srcShapeCast.getSource());
-        return success();
-      }
-    }
-    return failure();
+    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!srcShapeCast)
+      return failure();
+
+    VectorType srcType = srcShapeCast.getSourceVectorType();
+    VectorType destType = broadcastOp.getResultVectorType();
+    if (vector::isBroadcastableTo(srcType, destType) !=
+        BroadcastableToResult::Success)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                             srcShapeCast.getSource());
+    return success();
   }
 };
 } // namespace
>From 067f1150c3b6ea87cd9b09f64949b92d22087c28 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min at myhsu.dev>
Date: Fri, 25 Jul 2025 09:08:06 -0700
Subject: [PATCH 3/7] fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0fd2acd06c8ec..fc4ef6bf39379 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1182,7 +1182,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
 // CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
 //       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
 //       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
-func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
   %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
   %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
   return %1 : vector<2x4x16xf32>
>From 32c870b8ad9bd285652b2606c8c31f800d4343f9 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 13:13:29 -0700
Subject: [PATCH 4/7] fixup! fixup! Update
 mlir/test/Dialect/Vector/canonicalize.mlir
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index fc4ef6bf39379..776c75114ed44 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1179,7 +1179,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
 
 // -----
 
-// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
 //       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
 //       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
 func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
>From 0cf5cc19908b5b88a3a8d9775c4061ab8ca26f2c Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 5 Aug 2025 16:31:29 -0700
Subject: [PATCH 5/7] fixup! Fix invalid folding on mismatching broadcast
 dimensions
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 33 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 13 ++++++++-
 2 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0bc62d832b403..2877527ae095a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
   }
 };
 
+// Return the broadcasted dimensions. Including broadcasts in the leading
+// dimensions and broadcasts through unit dimension (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+                                    ArrayRef<int64_t> destShape) {
+  assert(destShape.size() >= srcShape.size());
+  BitVector broadcastedDims(destShape.size());
+  broadcastedDims.set(0, destShape.size() - srcShape.size());
+  auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+  for (int64_t dim : unitDims)
+    broadcastedDims.set(dim);
+  return broadcastedDims;
+}
+
 // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type.
+// with broadcast's result type and the broadcasted dimensions are the same.
 struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
 
     VectorType srcType = srcShapeCast.getSourceVectorType();
     VectorType destType = broadcastOp.getResultVectorType();
+    // Check type compatibility.
     if (vector::isBroadcastableTo(srcType, destType) !=
         BroadcastableToResult::Success)
       return failure();
 
+    // Given
+    // ```
+    // %s = shape_cast(%x)
+    // %b = broadcast(%s)
+    // ```
+    // If we want to fold %x into %b, the broadcasted dimensions from %x to
+    // %b has to be the same as that of from %s to %b.
+    ArrayRef<int64_t> shapecastShape =
+        srcShapeCast.getResultVectorType().getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    ArrayRef<int64_t> destShape = destType.getShape();
+    BitVector origBroadcastedDims =
+        getBroadcastedDims(shapecastShape, destShape);
+    BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
+    if (newBroadcastedDims != origBroadcastedDims)
+      return failure();
+
     rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                              srcShapeCast.getSource());
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d2b3f9028b301..7c19d5ea41bfb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1180,7 +1180,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
 // -----
 
 // CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
-//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
 //       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
 func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
   %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
@@ -1190,6 +1190,17 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vecto
 
 // -----
 
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
+//       CHECK:   vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
+//       CHECK:   vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 236c5459f7c3256d11cf6dc8aabd0ab0da964261 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 5 Aug 2025 16:56:43 -0700
Subject: [PATCH 6/7] fixup! Rewrite as a folding pattern
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 106 +++++++++++------------
 1 file changed, 51 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2877527ae095a..abdbe7581487e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
+// Return the broadcasted dimensions. Including broadcasts in the leading
+// dimensions and broadcasts through unit dimension (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+                                    ArrayRef<int64_t> destShape) {
+  assert(destShape.size() >= srcShape.size());
+  BitVector broadcastedDims(destShape.size());
+  broadcastedDims.set(0, destShape.size() - srcShape.size());
+  auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+  for (int64_t dim : unitDims)
+    broadcastedDims.set(dim);
+  return broadcastedDims;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+  if (!srcShapeCast)
+    return failure();
+
+  VectorType srcType = srcShapeCast.getSourceVectorType();
+  VectorType destType = broadcastOp.getResultVectorType();
+  // Check type compatibility.
+  if (vector::isBroadcastableTo(srcType, destType) !=
+      BroadcastableToResult::Success)
+    return failure();
+
+  // Given
+  // ```
+  // %s = shape_cast(%x)
+  // %b = broadcast(%s)
+  // ```
+  // If we want to fold %x into %b, the broadcasted dimensions from %x to
+  // %b has to be the same as that of from %s to %b.
+  ArrayRef<int64_t> shapecastShape =
+      srcShapeCast.getResultVectorType().getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> destShape = destType.getShape();
+  BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
+  BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
+  if (newBroadcastedDims != origBroadcastedDims)
+    return failure();
+
+  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+  return success();
+}
+
 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getSourceType() == getResultVectorType())
     return getSource();
+  if (succeeded(foldBroadcastOfShapeCast(*this)))
+    return getResult();
+
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();
@@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
-
-// Return the broadcasted dimensions. Including broadcasts in the leading
-// dimensions and broadcasts through unit dimension (i.e. dim-1).
-static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
-                                    ArrayRef<int64_t> destShape) {
-  assert(destShape.size() >= srcShape.size());
-  BitVector broadcastedDims(destShape.size());
-  broadcastedDims.set(0, destShape.size() - srcShape.size());
-  auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
-  for (int64_t dim : unitDims)
-    broadcastedDims.set(dim);
-  return broadcastedDims;
-}
-
-// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type and the broadcasted dimensions are the same.
-struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
-    if (!srcShapeCast)
-      return failure();
-
-    VectorType srcType = srcShapeCast.getSourceVectorType();
-    VectorType destType = broadcastOp.getResultVectorType();
-    // Check type compatibility.
-    if (vector::isBroadcastableTo(srcType, destType) !=
-        BroadcastableToResult::Success)
-      return failure();
-
-    // Given
-    // ```
-    // %s = shape_cast(%x)
-    // %b = broadcast(%s)
-    // ```
-    // If we want to fold %x into %b, the broadcasted dimensions from %x to
-    // %b has to be the same as that of from %s to %b.
-    ArrayRef<int64_t> shapecastShape =
-        srcShapeCast.getResultVectorType().getShape();
-    ArrayRef<int64_t> srcShape = srcType.getShape();
-    ArrayRef<int64_t> destShape = destType.getShape();
-    BitVector origBroadcastedDims =
-        getBroadcastedDims(shapecastShape, destShape);
-    BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
-    if (newBroadcastedDims != origBroadcastedDims)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
-                                             srcShapeCast.getSource());
-    return success();
-  }
-};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
+  results.add<BroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
>From e370b81aa9830798c1b968b164fafcb8e61a77eb Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 7 Aug 2025 16:26:10 -0700
Subject: [PATCH 7/7] fixup! Simplify the algorithm for the legality check
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 29 ++++++++++++------------
 1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index abdbe7581487e..1d49442775fb8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() {
 }
 
 // Return the broadcasted dimensions. Including broadcasts in the leading
-// dimensions and broadcasts through unit dimension (i.e. dim-1).
+// dimensions and broadcasts through unit dimension.
 static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
                                     ArrayRef<int64_t> destShape) {
   assert(destShape.size() >= srcShape.size());
@@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
 }
 
 // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type and the broadcasted dimensions are the same.
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
 static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
   auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
   if (!srcShapeCast)
@@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
       BroadcastableToResult::Success)
     return failure();
 
-  // Given
-  // ```
-  // %s = shape_cast(%x)
-  // %b = broadcast(%s)
-  // ```
-  // If we want to fold %x into %b, the broadcasted dimensions from %x to
-  // %b has to be the same as that of from %s to %b.
+  ArrayRef<int64_t> srcShape = srcType.getShape();
   ArrayRef<int64_t> shapecastShape =
       srcShapeCast.getResultVectorType().getShape();
-  ArrayRef<int64_t> srcShape = srcType.getShape();
-  ArrayRef<int64_t> destShape = destType.getShape();
-  BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
-  BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
-  if (newBroadcastedDims != origBroadcastedDims)
+  // Trailing dimensions should be the same if shape_cast only alters the
+  // leading dimensions.
+  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+  if (!llvm::equal(srcShape.take_back(numTrailingDims),
+                   shapecastShape.take_back(numTrailingDims)))
     return failure();
 
+  assert(all_of(srcShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         all_of(shapecastShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         "ill-formed shape_cast");
+
   broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
   return success();
 }
    
    
More information about the Mlir-commits
mailing list