[Mlir-commits] [mlir] 9f37c21 - [mlir][vector] Fold shape_cast of broadcast with same element count
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 15 12:06:19 PDT 2023
Author: Lei Zhang
Date: 2023-08-15T11:26:15-07:00
New Revision: 9f37c21349580d88858a463df88235923a9bf7e0
URL: https://github.com/llvm/llvm-project/commit/9f37c21349580d88858a463df88235923a9bf7e0
DIFF: https://github.com/llvm/llvm-project/commit/9f37c21349580d88858a463df88235923a9bf7e0.diff
LOG: [mlir][vector] Fold shape_cast of broadcast with same element count
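When the broadcast source has the same element count as the shape_cast result, the broadcast replicates no elements, so the shape_cast(broadcast) pair can be replaced by a single shape_cast of the broadcast source, simplifying the IR.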
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D157929
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0d4f8952244f9d..f5a7cdc556b515 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4757,9 +4757,10 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
};
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
-/// This only applies when the shape of the broadcast source is a suffix of the
-/// shape of the result (i.e. when broadcast without reshape is expressive
-/// enough to capture the result in a single op).
+/// This only applies when the shape of the broadcast source
+/// 1. is a suffix of the shape of the result (i.e. when broadcast without
+/// reshape is expressive enough to capture the result in a single op), or
+/// 2. has the same element count as the shape cast result.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -4771,23 +4772,35 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
if (!broadcastOp)
return failure();
- auto broadcastSourceVectorType =
- llvm::dyn_cast<VectorType>(broadcastOp.getSourceType());
- auto broadcastSourceShape = broadcastSourceVectorType
- ? broadcastSourceVectorType.getShape()
- : ArrayRef<int64_t>{};
- auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
-
- // Bail if `broadcastSourceShape` is not a suffix of the result.
- bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
- broadcastSourceShape.size()));
- if (!isSuffix)
- return failure();
+ ArrayRef<int64_t> broadcastSourceShape;
+ if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
+ broadcastSourceShape = srcType.getShape();
+ ArrayRef<int64_t> shapeCastTargetShape =
+ shapeCastOp.getResultVectorType().getShape();
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
+ // If `broadcastSourceShape` is a suffix of the result, we can just replace
+ // with a broadcast to the final shape.
+ if (broadcastSourceShape ==
+ shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ shapeCastOp, shapeCastOp.getResultVectorType(),
+ broadcastOp.getSource());
+ return success();
+ }
+
+ // Otherwise, if the final result has the same element count, we can replace
+ // with a shape cast.
+ if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+ if (srcType.getNumElements() ==
+ shapeCastOp.getResultVectorType().getNumElements()) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ shapeCastOp, shapeCastOp.getResultVectorType(),
+ broadcastOp.getSource());
+ return success();
+ }
+ }
+
+ return failure();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index be266bbc6c9ac8..2f76fc5d5ebdb2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,10 +714,10 @@ func.func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vec
// -----
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
-func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
+func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
%0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
%1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
return %1 : vector<8x3xf32>
@@ -725,6 +725,17 @@ func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3x
// -----
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
+ %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
+ return %1 : vector<1x12xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfers
func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits mailing list