[Mlir-commits] [mlir] a48bdee - [mlir][Vector] Add a ShapeCastOp(BroadcastOp) canonicalization pattern

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 28 09:49:41 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-28T09:49:37-07:00
New Revision: a48bdee6866c79275bcf650a57c6c19fdf63991e

URL: https://github.com/llvm/llvm-project/commit/a48bdee6866c79275bcf650a57c6c19fdf63991e
DIFF: https://github.com/llvm/llvm-project/commit/a48bdee6866c79275bcf650a57c6c19fdf63991e.diff

LOG: [mlir][Vector] Add a ShapeCastOp(BroadcastOp) canonicalization pattern

This pattern can kick in when the shape of the broadcast source is a
suffix of the result shape of the shape_cast.

Differential Revision: https://reviews.llvm.org/D128734

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8332c0b8b260b..cb4fe51d41d2e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4188,12 +4188,46 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
+/// This only applies when the shape of the broadcast source is a suffix of the
+/// shape of the result (i.e. when broadcast without reshape is expressive
+/// enough to capture the result in a single op).
+class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto broadcastOp =
+        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+
+    auto broadcastSourceVectorType =
+        broadcastOp.getSourceType().dyn_cast<VectorType>();
+    auto broadcastSourceShape = broadcastSourceVectorType
+                                    ? broadcastSourceVectorType.getShape()
+                                    : ArrayRef<int64_t>{};
+    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
+
+    // Bail if `broadcastSourceShape` is not a suffix of the result.
+    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
+                                                 broadcastSourceShape.size()));
+    if (!isSuffix)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        shapeCastOp, shapeCastOp.getResultVectorType(),
+        broadcastOp.getSource());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
-  results.add<ShapeCastConstantFolder>(context);
+  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e066e19eb9137..d16a6fa2c7e11 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -646,10 +646,10 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
 //       CHECK:   vector.broadcast
-//       CHECK:   vector.shape_cast
-func.func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
     %0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
     %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
     return %1 : vector<1xf32>
@@ -668,6 +668,17 @@ func.func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vec
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast
+//       CHECK:   vector.broadcast
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
+    %0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
+    %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
+    return %1 : vector<8x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list