[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Min-Yih Hsu
llvmlistbot at llvm.org
Fri Jul 25 09:08:15 PDT 2025
https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/150523
>From 9ca07a1022b7421e740390dff3e5aa2046a24e61 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Thu, 24 Jul 2025 13:55:56 -0700
Subject: [PATCH 1/3] [mlir][vector] Canonicalize broadcast of shape_cast
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is
compatible with broadcast's result type.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ if (auto srcShapeCast =
+ broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
+ }
+ }
+ return failure();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 10a914efacadd06d8dc40c266c1a85416d546782 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min.hsu at sifive.com>
Date: Fri, 25 Jul 2025 09:06:35 -0700
Subject: [PATCH 2/3] fixup! Address review comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++------------
1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad908319d8584..348c713980ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
- if (auto srcShapeCast =
- broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
- VectorType srcType = srcShapeCast.getSourceVectorType();
- VectorType destType = broadcastOp.getResultVectorType();
- if (vector::isBroadcastableTo(srcType, destType) ==
- BroadcastableToResult::Success) {
- rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
- srcShapeCast.getSource());
- return success();
- }
- }
- return failure();
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
}
};
} // namespace
>From 067f1150c3b6ea87cd9b09f64949b92d22087c28 Mon Sep 17 00:00:00 2001
From: Min-Yih Hsu <min at myhsu.dev>
Date: Fri, 25 Jul 2025 09:08:06 -0700
Subject: [PATCH 3/3] fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0fd2acd06c8ec..fc4ef6bf39379 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1182,7 +1182,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
-func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
%1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
return %1 : vector<2x4x16xf32>
More information about the Mlir-commits
mailing list