[Mlir-commits] [mlir] 9f37c21 - [mlir][vector] Fold shape_cast of broadcast with same element count

Lei Zhang llvmlistbot at llvm.org
Tue Aug 15 12:06:19 PDT 2023


Author: Lei Zhang
Date: 2023-08-15T11:26:15-07:00
New Revision: 9f37c21349580d88858a463df88235923a9bf7e0

URL: https://github.com/llvm/llvm-project/commit/9f37c21349580d88858a463df88235923a9bf7e0
DIFF: https://github.com/llvm/llvm-project/commit/9f37c21349580d88858a463df88235923a9bf7e0.diff

LOG: [mlir][vector] Fold shape_cast of broadcast with same element count

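When the broadcast does not change the element count (e.g. it only adds leading unit dimensions), the shape_cast(broadcast) pair can be replaced by a single shape_cast of the broadcast source, simplifying the IR.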

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D157929

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0d4f8952244f9d..f5a7cdc556b515 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4757,9 +4757,10 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
 };
 
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
-/// This only applies when the shape of the broadcast source is a suffix of the
-/// shape of the result (i.e. when broadcast without reshape is expressive
-/// enough to capture the result in a single op).
+/// This only applies when the shape of the broadcast source
+/// 1. is a suffix of the shape of the result (i.e. when broadcast without
+///    reshape is expressive enough to capture the result in a single op), or
+/// 2. has the same element count as the shape cast result.
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -4771,23 +4772,35 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     if (!broadcastOp)
       return failure();
 
-    auto broadcastSourceVectorType =
-        llvm::dyn_cast<VectorType>(broadcastOp.getSourceType());
-    auto broadcastSourceShape = broadcastSourceVectorType
-                                    ? broadcastSourceVectorType.getShape()
-                                    : ArrayRef<int64_t>{};
-    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
-
-    // Bail if `broadcastSourceShape` is not a suffix of the result.
-    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
-                                                 broadcastSourceShape.size()));
-    if (!isSuffix)
-      return failure();
+    ArrayRef<int64_t> broadcastSourceShape;
+    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
+      broadcastSourceShape = srcType.getShape();
+    ArrayRef<int64_t> shapeCastTargetShape =
+        shapeCastOp.getResultVectorType().getShape();
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        shapeCastOp, shapeCastOp.getResultVectorType(),
-        broadcastOp.getSource());
-    return success();
+    // If `broadcastSourceShape` is a suffix of the result, we can just replace
+    // with a broadcast to the final shape.
+    if (broadcastSourceShape ==
+        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          shapeCastOp, shapeCastOp.getResultVectorType(),
+          broadcastOp.getSource());
+      return success();
+    }
+
+    // Otherwise, if the final result has the same element count, we can replace
+    // with a shape cast.
+    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+      if (srcType.getNumElements() ==
+          shapeCastOp.getResultVectorType().getNumElements()) {
+        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+            shapeCastOp, shapeCastOp.getResultVectorType(),
+            broadcastOp.getSource());
+        return success();
+      }
+    }
+
+    return failure();
   }
 };
 

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index be266bbc6c9ac8..2f76fc5d5ebdb2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,10 +714,10 @@ func.func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vec
 
 // -----
 
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast
 //       CHECK:   vector.broadcast
 //   CHECK-NOT:   vector.shape_cast
-func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
+func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
     %0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
     %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
     return %1 : vector<8x3xf32>
@@ -725,6 +725,17 @@ func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3x
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
+    %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>  // 12 elements in and out, so the broadcast is not needed.
+    %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>  // Expected to fold into a single shape_cast of %arg0.
+    return %1 : vector<1x12xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list