[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)

Hugo Trachino llvmlistbot at llvm.org
Tue May 21 10:06:25 PDT 2024


https://github.com/nujaa created https://github.com/llvm/llvm-project/pull/92938

This PR is part of a series of PRs that generalize `DropUnitDimFromElementwiseOps` to other ops.
This commit implements `DropUnitDimFromBroadcastOp`, which targets `vector::BroadcastOp`.

>From 6df637789eca71fa26e059f239e80005f9968d1c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 16 May 2024 19:08:32 +0800
Subject: [PATCH] [MLIR][Vector] Enable DropUnitDimFromBroadcastOp

---
 .../Vector/Transforms/VectorTransforms.cpp    | 64 ++++++++++++++++++-
 .../Vector/vector-transfer-flatten.mlir       | 55 ++++++++++++++++
 2 files changed, 117 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..a8494eac3e5aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Drops non-scalable unit dims that a broadcastOp carries over unchanged from
+/// source to result: a shape_cast removes them from the source, a lower-rank
+/// broadcast is created, and a shape_cast after it restores the result type.
+/// Unit dims stretched by the broadcast, or whose result dim is scalable, are
+/// kept. The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///  %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+///    vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+///    vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVT)
+      return failure();
+    auto resVT = broadcastOp.getResultVectorType();
+    VectorType newSrcVT = srcVT;
+    VectorType newResVT = resVT;
+    // NOTE: The returned indices are positions in the *result* vector.
+    auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+    // Source and result are aligned on their trailing dims. Walking backwards
+    // lets us drop dims without shifting the indices still to be visited.
+    for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+      int64_t srcIdx = srcVT.getRank() - dim.index() - 1;
+      int64_t resIdx = resVT.getRank() - dim.index() - 1;
+      if (dim.value() == 1 && !srcVT.getScalableDims()[srcIdx] &&
+          !resVT.getScalableDims()[resIdx] &&
+          !broadcastedUnitDims.contains(resIdx)) {
+        newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcIdx);
+        newResVT = VectorType::Builder(newResVT).dropDim(resIdx);
+      }
+    }
+
+    if (newSrcVT == srcVT)
+      return failure();
+    auto loc = broadcastOp->getLoc();
+    auto newSource = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcVT, broadcastOp.getSource());
+    auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+                                             newOp.getResult());
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..f1fc443b9d4bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -460,6 +460,61 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 
+// -----
+
+func.func @drop_broadcast_unit_dim(%src : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+  %res = vector.broadcast %src : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+  return %res : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcast_unit_dim(
+// CHECK-SAME:      %[[SRC:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK:           %[[SRC_DROPPED:.*]] = vector.shape_cast %[[SRC]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SRC_DROPPED]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK:           %[[RES:.*]] = vector.shape_cast %[[BCAST]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK:           return %[[RES]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%src : vector<1xf32>) -> vector<1x1xf32> {
+  %res = vector.broadcast %src : vector<1xf32> to vector<1x1xf32>
+  return %res : vector<1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME:      %[[SRC:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK:           %[[SRC_DROPPED:.*]] = vector.shape_cast %[[SRC]] : vector<1xf32> to vector<f32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SRC_DROPPED]] : vector<f32> to vector<1xf32>
+// CHECK:           %[[RES:.*]] = vector.shape_cast %[[BCAST]] : vector<1xf32> to vector<1x1xf32>
+// CHECK:           return %[[RES]] : vector<1x1xf32>
+
+// -----
+
+// Unit dims created by the broadcast itself are not dropped: a single broadcast
+// is preferred over a broadcast followed by a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%src : vector<4xf32>) -> vector<3x1x4xf32> {
+  %res = vector.broadcast %src : vector<4xf32> to vector<3x1x4xf32>
+  return %res : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME:      %[[SRC:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SRC]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x1x4xf32>
+
+// -----
+
+// A unit dim stretched by the broadcast must be kept, also when leading dims are added.
+func.func @drop_broadcasted_unit_dim(%src : vector<2x1x4xf32>) -> vector<5x2x3x4xf32> {
+  %res = vector.broadcast %src : vector<2x1x4xf32> to vector<5x2x3x4xf32>
+  return %res : vector<5x2x3x4xf32>
+}
+// CHECK-LABEL:   func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME:      %[[SRC:.*]]: vector<2x1x4xf32>{{.*}}-> vector<5x2x3x4xf32> {
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SRC]] : vector<2x1x4xf32> to vector<5x2x3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<5x2x3x4xf32>
+
+
 // -----
 
 func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,



More information about the Mlir-commits mailing list