[Mlir-commits] [mlir] [mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (PR #134939)
James Newling
llvmlistbot at llvm.org
Wed Apr 9 13:53:16 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/134939
>From 98daa18500210aedea35951e99c38add5fb5cd8c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 8 Apr 2025 15:45:48 -0700
Subject: [PATCH 1/3] cover additional cases of shape_cast(broadcast) ->
broadcast canonicalization
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++-------------
mlir/test/Dialect/Vector/canonicalize.mlir | 25 ++++++++++++++++
2 files changed, 39 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..c6d8ec1e1cf69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5778,8 +5778,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source
-/// 1. is a suffix of the shape of the result (i.e. when broadcast without
-/// reshape is expressive enough to capture the result in a single op), or
+/// 1. can be broadcast directly to the final shape, or
/// 2. has the same element count as the shape cast result.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
@@ -5792,24 +5791,20 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
if (!broadcastOp)
return failure();
- ArrayRef<int64_t> broadcastSourceShape;
- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
- broadcastSourceShape = srcType.getShape();
- ArrayRef<int64_t> shapeCastTargetShape =
- shapeCastOp.getResultVectorType().getShape();
-
- // If `broadcastSourceShape` is a suffix of the result, we can just replace
- // with a broadcast to the final shape.
- if (broadcastSourceShape ==
- shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
+ {
+ VectorType dstType = shapeCastOp.getResultVectorType();
+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ bool isScalar = !srcType;
+ if (isScalar || isBroadcastableTo(srcType, dstType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ shapeCastOp, dstType, broadcastOp.getSource());
+ return success();
+ }
}
- // Otherwise, if the final result has the same element count, we can replace
- // with a shape cast.
+ // If the final result has the same element count, we can replace with a
+ // shape cast.
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
if (srcType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
@@ -6079,7 +6074,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
-// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d7617d79b5cbf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1017,6 +1017,31 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
+// CHECK: vector.broadcast
+// CHECK-SAME: f32 to vector<3x4x1xf32>
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<12xf32>
+ %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
+ return %1 : vector<3x4x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
+// CHECK: vector.broadcast
+// CHECK-SAME: vector<1x1xi8> to vector<1x1x6x1x4xi8>
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+ %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+ return %1 : vector<1x1x6x1x4xi8>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From b58b83729c1210e706263d24ef681623b2f152ea Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 13:47:43 -0700
Subject: [PATCH 2/3] address review comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 50 +++++++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 49 ++++++++++++---------
2 files changed, 59 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c6d8ec1e1cf69..5b9a5d53b2ae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5776,10 +5776,12 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
-/// This only applies when the shape of the broadcast source
-/// 1. can be broadcast directly to the final shape, or
-/// 2. has the same element count as the shape cast result.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
+///
+/// 1) Y = ShapeCast(X), or
+/// 2) Y = Broadcast(X)
+///
+/// If both (1) and (2) are possible, (1) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -5791,22 +5793,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
if (!broadcastOp)
return failure();
- {
- VectorType dstType = shapeCastOp.getResultVectorType();
- auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
- bool isScalar = !srcType;
- if (isScalar || isBroadcastableTo(srcType, dstType) ==
- BroadcastableToResult::Success) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, dstType, broadcastOp.getSource());
- return success();
- }
- }
-
- // If the final result has the same element count, we can replace with a
- // shape cast.
- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
- if (srcType.getNumElements() ==
+ auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ bool srcIsScalar = !srcVectorType;
+
+ // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
+ // Example:
+ // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
+ // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
+ // to
+ // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
+ if (srcVectorType) {
+ if (srcVectorType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
shapeCastOp, shapeCastOp.getResultVectorType(),
@@ -5815,6 +5812,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
}
+ // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
+ // Example:
+ // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
+ // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
+ // to
+ // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
+ VectorType dstVectorType = shapeCastOp.getResultVectorType();
+ if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ shapeCastOp, dstVectorType, broadcastOp.getSource());
+ return success();
+ }
return failure();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d7617d79b5cbf..f4d1eccef2514 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1006,23 +1006,21 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -
// -----
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
- %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
- %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
- return %1 : vector<1x12xf32>
+// CHECK-LABEL: func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones
+// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+ %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+ return %1 : vector<1x1x6x1x4xi8>
}
// -----
-
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
-// CHECK: vector.broadcast
-// CHECK-SAME: f32 to vector<3x4x1xf32>
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
+// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
// CHECK-NOT: vector.shape_cast
-func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
return %1 : vector<3x4x1xf32>
@@ -1030,14 +1028,25 @@ func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf
// -----
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
-// CHECK: vector.broadcast
-// CHECK-SAME: vector<1x1xi8> to vector<1x1x6x1x4xi8>
-// CHECK-NOT: vector.shape_cast
-func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
- %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
- %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
- return %1 : vector<1x1x6x1x4xi8>
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
+ %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
+ return %1 : vector<1x12xf32>
+}
+
+// -----
+
+// This case could be folded to either a broadcast or a shape_cast; the shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast_priority
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapcast_priority(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+ return %1 : vector<1x1xf32>
}
// -----
>From 151bbfd5bb84eddba948543a4c1a3c5b6dc36308 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 13:57:51 -0700
Subject: [PATCH 3/3] spacing/naming improvements
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 +++-----
mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++--
2 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5b9a5d53b2ae0..59a7ea761a5ce 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5777,11 +5777,9 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
};
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///
-/// 1) Y = ShapeCast(X), or
-/// 2) Y = Broadcast(X)
-///
-/// If both (1) and (2) are possible, (1) is chosen.
+/// i) Y = ShapeCast(X), or
+/// ii) Y = Broadcast(X)
+/// If both (i) and (ii) are possible, (i) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f4d1eccef2514..8f90094236ddb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1006,10 +1006,10 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -
// -----
-// CHECK-LABEL: func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones
// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
// CHECK-NOT: vector.shape_cast
-func.func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
return %1 : vector<1x1x6x1x4xi8>
More information about the Mlir-commits
mailing list