[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Min-Yih Hsu
llvmlistbot at llvm.org
Thu Aug 7 16:35:44 PDT 2025
https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/150523
>From 9ca07a1022b7421e740390dff3e5aa2046a24e61 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 13:55:56 -0700
Subject: [PATCH 1/8] [mlir][vector] Canonicalize broadcast of shape_cast
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ if (auto srcShapeCast =
+ broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
+ }
+ }
+ return failure();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 10a914efacadd06d8dc40c266c1a85416d546782 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 09:06:35 -0700
Subject: [PATCH 2/8] fixup! Address review comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++------------
1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad908319d8584..348c713980ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
- if (auto srcShapeCast =
- broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
- VectorType srcType = srcShapeCast.getSourceVectorType();
- VectorType destType = broadcastOp.getResultVectorType();
- if (vector::isBroadcastableTo(srcType, destType) ==
- BroadcastableToResult::Success) {
- rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
- srcShapeCast.getSource());
- return success();
- }
- }
- return failure();
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
}
};
} // namespace
>From 067f1150c3b6ea87cd9b09f64949b92d22087c28 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min at myhsu.dev>
Date: Fri, 25 Jul 2025 09:08:06 -0700
Subject: [PATCH 3/8] fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0fd2acd06c8ec..fc4ef6bf39379 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1182,7 +1182,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
-func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
%1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
return %1 : vector<2x4x16xf32>
>From 32c870b8ad9bd285652b2606c8c31f800d4343f9 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 13:13:29 -0700
Subject: [PATCH 4/8] fixup! fixup! Update
mlir/test/Dialect/Vector/canonicalize.mlir
---
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index fc4ef6bf39379..776c75114ed44 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1179,7 +1179,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
// -----
-// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
>From 0cf5cc19908b5b88a3a8d9775c4061ab8ca26f2c Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 5 Aug 2025 16:31:29 -0700
Subject: [PATCH 5/8] fixup! Fix invalid folding on mismatching broadcast
dimensions
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 13 ++++++++-
2 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0bc62d832b403..2877527ae095a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
}
};
+// Return the broadcasted dimensions. Including broadcasts in the leading
+// dimensions and broadcasts through unit dimension (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> destShape) {
+ assert(destShape.size() >= srcShape.size());
+ BitVector broadcastedDims(destShape.size());
+ broadcastedDims.set(0, destShape.size() - srcShape.size());
+ auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+ for (int64_t dim : unitDims)
+ broadcastedDims.set(dim);
+ return broadcastedDims;
+}
+
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type.
+// with broadcast's result type and the broadcasted dimensions are the same.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
using OpRewritePattern::OpRewritePattern;
@@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
VectorType srcType = srcShapeCast.getSourceVectorType();
VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
if (vector::isBroadcastableTo(srcType, destType) !=
BroadcastableToResult::Success)
return failure();
+ // Given
+ // ```
+ // %s = shape_cast(%x)
+ // %b = broadcast(%s)
+ // ```
+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
+ // %b has to be the same as that of from %s to %b.
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> destShape = destType.getShape();
+ BitVector origBroadcastedDims =
+ getBroadcastedDims(shapecastShape, destShape);
+ BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
+ if (newBroadcastedDims != origBroadcastedDims)
+ return failure();
+
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
srcShapeCast.getSource());
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d2b3f9028b301..7c19d5ea41bfb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1180,7 +1180,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
// -----
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
-// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
@@ -1190,6 +1190,17 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vecto
// -----
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
+// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
+// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 236c5459f7c3256d11cf6dc8aabd0ab0da964261 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Tue, 5 Aug 2025 16:56:43 -0700
Subject: [PATCH 6/8] fixup! Rewrite as a folding pattern
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 106 +++++++++++------------
1 file changed, 51 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2877527ae095a..abdbe7581487e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Return the broadcasted dimensions. Including broadcasts in the leading
+// dimensions and broadcasts through unit dimension (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> destShape) {
+ assert(destShape.size() >= srcShape.size());
+ BitVector broadcastedDims(destShape.size());
+ broadcastedDims.set(0, destShape.size() - srcShape.size());
+ auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+ for (int64_t dim : unitDims)
+ broadcastedDims.set(dim);
+ return broadcastedDims;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ // Given
+ // ```
+ // %s = shape_cast(%x)
+ // %b = broadcast(%s)
+ // ```
+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
+ // %b has to be the same as that of from %s to %b.
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> destShape = destType.getShape();
+ BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
+ BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
+ if (newBroadcastedDims != origBroadcastedDims)
+ return failure();
+
+ broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getResultVectorType())
return getSource();
+ if (succeeded(foldBroadcastOfShapeCast(*this)))
+ return getResult();
+
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
@@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
-
-// Return the broadcasted dimensions. Including broadcasts in the leading
-// dimensions and broadcasts through unit dimension (i.e. dim-1).
-static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
- ArrayRef<int64_t> destShape) {
- assert(destShape.size() >= srcShape.size());
- BitVector broadcastedDims(destShape.size());
- broadcastedDims.set(0, destShape.size() - srcShape.size());
- auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
- for (int64_t dim : unitDims)
- broadcastedDims.set(dim);
- return broadcastedDims;
-}
-
-// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type and the broadcasted dimensions are the same.
-struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
- PatternRewriter &rewriter) const override {
- auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
- if (!srcShapeCast)
- return failure();
-
- VectorType srcType = srcShapeCast.getSourceVectorType();
- VectorType destType = broadcastOp.getResultVectorType();
- // Check type compatibility.
- if (vector::isBroadcastableTo(srcType, destType) !=
- BroadcastableToResult::Success)
- return failure();
-
- // Given
- // ```
- // %s = shape_cast(%x)
- // %b = broadcast(%s)
- // ```
- // If we want to fold %x into %b, the broadcasted dimensions from %x to
- // %b has to be the same as that of from %s to %b.
- ArrayRef<int64_t> shapecastShape =
- srcShapeCast.getResultVectorType().getShape();
- ArrayRef<int64_t> srcShape = srcType.getShape();
- ArrayRef<int64_t> destShape = destType.getShape();
- BitVector origBroadcastedDims =
- getBroadcastedDims(shapecastShape, destShape);
- BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
- if (newBroadcastedDims != origBroadcastedDims)
- return failure();
-
- rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
- srcShapeCast.getSource());
- return success();
- }
-};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
+ results.add<BroadcastFolder>(context);
}
//===----------------------------------------------------------------------===//
>From e370b81aa9830798c1b968b164fafcb8e61a77eb Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 7 Aug 2025 16:26:10 -0700
Subject: [PATCH 7/8] fixup! Simplify the algorithm for the legality check
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 29 ++++++++++++------------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index abdbe7581487e..1d49442775fb8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() {
}
// Return the broadcasted dimensions. Including broadcasts in the leading
-// dimensions and broadcasts through unit dimension (i.e. dim-1).
+// dimensions and broadcasts through unit dimension.
static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> destShape) {
assert(destShape.size() >= srcShape.size());
@@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
}
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
-// with broadcast's result type and the broadcasted dimensions are the same.
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
if (!srcShapeCast)
@@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
BroadcastableToResult::Success)
return failure();
- // Given
- // ```
- // %s = shape_cast(%x)
- // %b = broadcast(%s)
- // ```
- // If we want to fold %x into %b, the broadcasted dimensions from %x to
- // %b has to be the same as that of from %s to %b.
+ ArrayRef<int64_t> srcShape = srcType.getShape();
ArrayRef<int64_t> shapecastShape =
srcShapeCast.getResultVectorType().getShape();
- ArrayRef<int64_t> srcShape = srcType.getShape();
- ArrayRef<int64_t> destShape = destType.getShape();
- BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape);
- BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
- if (newBroadcastedDims != origBroadcastedDims)
+ // Trailing dimensions should be the same if shape_cast only alters the
+ // leading dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
return failure();
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
+
broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
return success();
}
>From 6755a75814fbe1002d2eb9e74ce8ee25340c3aed Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 7 Aug 2025 16:35:07 -0700
Subject: [PATCH 8/8] fixup! Add more test cases
---
mlir/test/Dialect/Vector/canonicalize.mlir | 71 +++++++++++++++++++++-
1 file changed, 69 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7c19d5ea41bfb..4a7176e1f8d7d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,10 +1168,10 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
-// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
-func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
return %1 : vector<32x2xf32>
@@ -1179,6 +1179,45 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
// -----
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
+ return %1 : vector<32x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
+ return %1 : vector<32x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
@@ -1201,6 +1240,34 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%a
// -----
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x4xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
+ %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_1]] : vector<32x2xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
More information about the Mlir-commits
mailing list