[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)

Hugo Trachino llvmlistbot at llvm.org
Thu Jun 20 06:06:26 PDT 2024


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/92938

>From d94eb3d81c6e53b90c022ee81c92ac5fac9be79e Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 16 May 2024 19:08:32 +0800
Subject: [PATCH 1/2] [MLIR][Vector] Enable DropUnitDimFromBroadcastOp

---
 .../Vector/Transforms/VectorTransforms.cpp    | 64 ++++++++++++++++++-
 .../Vector/vector-transfer-flatten.mlir       | 54 ++++++++++++++++
 2 files changed, 116 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 200517913677f..5d8b77621e0d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Drops non-scalable unit dimensions of a vector.broadcast that are shared
+/// by the source and the result (i.e. not stretched by the broadcast), using
+/// shape_casts: one before the op removes the unit dims from the source, and
+/// one after the op restores them. Scalar (non-vector) sources are ignored.
+///
+/// Ex:
+/// ```
+///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///  %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+///    vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+///    vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast fold away, leaving a single lower-rank broadcast.
+struct DropUnitDimFromBroadcastOp final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVT)
+      return failure();
+    auto resVT = broadcastOp.getResultVectorType();
+    VectorType newSrcVT = srcVT;
+    VectorType newResVT = resVT;
+    auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+    // Reversing allows us to remove dims from the back without keeping track of
+    // removed dimensions.
+    for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+      if (dim.value() == 1 &&
+          !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
+          !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
+        newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
+                                                         dim.index() - 1);
+        newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
+                                                         dim.index() - 1);
+      }
+    }
+
+    if (newSrcVT == srcVT)
+      return failure();
+    auto loc = broadcastOp->getLoc();
+    auto newSource = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcVT, broadcastOp.getSource());
+    auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+                                             newOp.getResult());
+    return success();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1827,8 +1887,9 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+               ShapeCastOpFolder>(ctx, benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 42bf7201daaa7..bbeffa979e9c3 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -535,6 +535,60 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
 
 // -----
 
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+  %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+  return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcast_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK:           %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK:           return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+  %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+  return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK:           %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
+// CHECK:           return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Unit dims created by the broadcast itself are not dropped: we prefer a
+// single broadcast over a broadcast followed by a shape_cast.
+func.func @negative_drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+  return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @negative_drop_broadcast_generated_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK:           return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A unit dim stretched by the broadcast (1 -> 3) must be kept; dropping it would produce mismatched types.
+func.func @negative_drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+  %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+  return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL:   func.func @negative_drop_broadcasted_unit_dim(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x3x4xf32>
+
+// -----
+
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index

>From cebfd7435ab7b63a32cbbc8f428d504ab3028f4c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 20 Jun 2024 18:28:18 +0800
Subject: [PATCH 2/2] Hoist out vector builder

---
 .../Vector/Transforms/VectorTransforms.cpp    | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5d8b77621e0d7..70349a61652ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1731,33 +1731,42 @@ struct DropUnitDimFromBroadcastOp final

   LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
-    if (!srcVT)
+    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVecTy)
       return failure();
-    auto resVT = broadcastOp.getResultVectorType();
-    VectorType newSrcVT = srcVT;
-    VectorType newResVT = resVT;
+    auto resVecTy = broadcastOp.getResultVectorType();
+    auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+    auto resVecTyBuilder = VectorType::Builder(resVecTy);
     auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
     // Reversing allows us to remove dims from the back without keeping track of
     // removed dimensions.
-    for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+    for (const auto &dim :
+         llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
+      // Source dims align with the trailing result dims, so one reversed index
+      // addresses the matching position in both types. Note that
+      // `computeBroadcastedUnitDims` returns indices into the *result* shape,
+      // and that a fixed 1 may be broadcast to a scalable [1] in the result.
+      int64_t srcIdx = srcVecTy.getRank() - dim.index() - 1;
+      int64_t resIdx = resVecTy.getRank() - dim.index() - 1;
       if (dim.value() == 1 &&
-          !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
-          !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
-        newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
-                                                         dim.index() - 1);
-        newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
-                                                         dim.index() - 1);
+          !srcVecTy.getScalableDims()[srcIdx] &&
+          !resVecTy.getScalableDims()[resIdx] &&
+          !broadcastedUnitDims.contains(resIdx)) {
+        srcVecTyBuilder.dropDim(srcIdx);
+        resVecTyBuilder.dropDim(resIdx);
       }
     }

-    if (newSrcVT == srcVT)
+    VectorType newSrcVecTy = srcVecTyBuilder;
+    if (newSrcVecTy == srcVecTy)
       return failure();
+    VectorType newResVecTy = resVecTyBuilder;
     auto loc = broadcastOp->getLoc();
     auto newSource = rewriter.create<vector::ShapeCastOp>(
-        loc, newSrcVT, broadcastOp.getSource());
-    auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+        loc, newSrcVecTy, broadcastOp.getSource());
+    auto newOp =
+        rewriter.create<vector::BroadcastOp>(loc, newResVecTy, newSource);
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVecTy,
                                              newOp.getResult());
     return success();
   }



More information about the Mlir-commits mailing list