[Mlir-commits] [mlir] [mlir][linalg] Add a pattern to drop unit dim of `linalg.broadcast` (PR #106533)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 29 04:45:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR adds a canonicalization pattern that drops unit dimensions of `linalg.broadcast`: a broadcast dimension whose size in the init shape is 1 can be dropped.
For example,
```
  %0 = linalg.broadcast ins(%input : tensor<4xf32>)
                        outs(%init : tensor<1x4xf32>)
                        dimensions = [0]
```
is converted to:
```
  %collapsed = tensor.collapse_shape %init [[0, 1]] :
                  tensor<1x4xf32> into tensor<4xf32>
  %0 = linalg.broadcast ins(%input : tensor<4xf32>)
                        outs(%collapsed : tensor<4xf32>)
                        dimensions = []
  %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4] :
                  tensor<4xf32> into tensor<1x4xf32>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/106533.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+77-1) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+65-20) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..da45abc682f129 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2134,9 +2134,85 @@ void BroadcastOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+/// Drops broadcast dims with static init size 1: each is collapsed into the
+/// group of the next kept dim (trailing ones join the last group), e.g.
+///    ```
+///    %0 = linalg.broadcast ins(%input : tensor<4xf32>)
+///                          outs(%init : tensor<1x4xf32>)
+///                          dimensions = [0]
+///    ```
+/// is converted to:
+///    ```
+///    %collapsed = tensor.collapse_shape %init [[0, 1]] :
+///                    tensor<1x4xf32> into tensor<4xf32>
+///    %0 = linalg.broadcast ins(%input : tensor<4xf32>)
+///                          outs(%collapsed : tensor<4xf32>)
+///                          dimensions = []
+///    %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4] :
+///                    tensor<4xf32> into tensor<1x4xf32>
+///    ```
+struct DropUnitDimOfBroadcastOp : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (!broadcastOp.hasPureTensorSemantics())
+      return failure();
+
+    Value init = broadcastOp.getInit();
+    auto initType = dyn_cast<RankedTensorType>(init.getType());
+    if (!initType)
+      return failure();
+    ArrayRef<int64_t> initShape = initType.getShape();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    int64_t rank = initType.getRank();
+    // Mark the broadcast dims of static size 1; these are the dropped ones.
+    SmallVector<bool> dropped(rank, false);
+    for (int64_t dim : dimensions)
+      dropped[dim] = initShape[dim] == 1;
+    if (!llvm::is_contained(dropped, true))
+      return failure();
+
+    // Renumber the kept broadcast dims for the collapsed init. `dimensions`
+    // may be unsorted, so count the dropped dims that precede each one.
+    SmallVector<int64_t> newDimensions;
+    for (int64_t dim : dimensions)
+      if (!dropped[dim])
+        newDimensions.push_back(
+            dim - llvm::count(ArrayRef<bool>(dropped).take_front(dim), true));
+
+    // Fold each dropped dim into the next kept dim's group; trailing dropped
+    // dims join the last group (none exists if all dims are dropped: rank 0).
+    SmallVector<ReassociationIndices> reassociation;
+    ReassociationIndices pending;
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      pending.push_back(dim);
+      if (!dropped[dim])
+        reassociation.push_back(std::exchange(pending, {}));
+    }
+    if (!reassociation.empty())
+      reassociation.back().append(pending.begin(), pending.end());
+
+    Location loc = broadcastOp.getLoc();
+    auto collapsedType =
+        tensor::CollapseShapeOp::inferCollapsedType(initType, reassociation);
+    auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
+        loc, collapsedType, init, reassociation);
+    Value newBroadcast =
+        rewriter
+            .create<linalg::BroadcastOp>(loc, broadcastOp.getInput(),
+                                         collapsedInit, newDimensions)
+            .getResult()[0];
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        broadcastOp, initType, newBroadcast, reassociation);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<DropUnitDimOfBroadcastOp>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..19900203f0f5ff 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1039,6 +1039,51 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL:   func.func @broadcast_unit_shape_front_fold(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: tensor<2x3xf32>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+// CHECK:           %[[VAL_2:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [1, 2, 3] : tensor<2x3xf32> into tensor<1x2x3xf32>
+// CHECK:           return %[[VAL_2]] : tensor<1x2x3xf32>
+// CHECK:         }
+func.func @broadcast_unit_shape_front_fold(%input: tensor<2x3xf32>, %init: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<1x2x3xf32>) dimensions = [0] // only unit dim 0: becomes an identity broadcast
+  return %0 : tensor<1x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_unit_shape_middle_fold(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: tensor<2xf32>,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: tensor<2x1x3xf32>) -> tensor<2x1x3xf32> {
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : tensor<2x1x3xf32> into tensor<2x3xf32>
+// CHECK:           %[[VAL_3:.*]] = linalg.broadcast ins(%[[VAL_0]] : tensor<2xf32>) outs(%[[VAL_2]] : tensor<2x3xf32>) dimensions = [1]
+// CHECK:           %[[VAL_4:.*]] = tensor.expand_shape %[[VAL_3]] {{\[\[}}0], [1, 2]] output_shape [2, 1, 3] : tensor<2x3xf32> into tensor<2x1x3xf32>
+// CHECK:           return %[[VAL_4]] : tensor<2x1x3xf32>
+// CHECK:         }
+func.func @broadcast_unit_shape_middle_fold(%input: tensor<2xf32>, %init: tensor<2x1x3xf32>) -> tensor<2x1x3xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2xf32>) outs(%init: tensor<2x1x3xf32>) dimensions = [1, 2] // unit dim 1 dropped, dim 2 kept
+  return %0 : tensor<2x1x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_unit_shape_dynamic_fold(
+// CHECK-SAME:                                                 %[[VAL_0:.*]]: tensor<2x?xf32>,
+// CHECK-SAME:                                                 %[[VAL_1:.*]]: tensor<2x1x3x?xf32>) -> tensor<2x1x3x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3]] : tensor<2x1x3x?xf32> into tensor<2x3x?xf32>
+// CHECK:           %[[VAL_4:.*]] = linalg.broadcast ins(%[[VAL_0]] : tensor<2x?xf32>) outs(%[[VAL_3]] : tensor<2x3x?xf32>) dimensions = [1]
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_2]] : tensor<2x1x3x?xf32>
+// CHECK:           %[[VAL_6:.*]] = tensor.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2], [3]] output_shape [2, 1, 3, %[[VAL_5]]] : tensor<2x3x?xf32> into tensor<2x1x3x?xf32>
+// CHECK:           return %[[VAL_6]] : tensor<2x1x3x?xf32>
+// CHECK:         }
+func.func @broadcast_unit_shape_dynamic_fold(%input: tensor<2x?xf32>, %init: tensor<2x1x3x?xf32>) -> tensor<2x1x3x?xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2x?xf32>) outs(%init: tensor<2x1x3x?xf32>) dimensions = [1, 2] // unit dim 1 dropped; dynamic dim 3 preserved
+  return %0 : tensor<2x1x3x?xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose
@@ -1119,53 +1164,53 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
 // -----
 
 func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
-                                    %init1: tensor<1x2x3x4x5x6xf32>,
-                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+                                    %init1: tensor<7x2x3x4x5x6xf32>, // non-unit leading dim keeps DropUnitDimOfBroadcastOp from firing
+                                    %init2: tensor<7x6x2x3x5x4xf32>) -> tensor<7x6x2x3x5x4xf32> {
   // CHECK-LABEL: @broadcast_transpose_fold
   //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
-  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
-  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<7x2x3x4x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<7x6x2x3x5x4xf32>
   //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
   //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
-  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
-  //       CHECK:   return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<7x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+  //       CHECK:   return %[[BROADCAST]] : tensor<7x6x2x3x5x4xf32>
   %broadcast = linalg.broadcast
       ins(%input : tensor<2x4x5xf32>)
-      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      outs(%init1 : tensor<7x2x3x4x5x6xf32>)
       dimensions = [0, 2, 5]
   %transpose = linalg.transpose
-      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
-      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      ins(%broadcast : tensor<7x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<7x6x2x3x5x4xf32>)
       permutation = [0, 5, 1, 2, 4, 3]
-  func.return %transpose : tensor<1x6x2x3x5x4xf32>
+  func.return %transpose : tensor<7x6x2x3x5x4xf32>
 }
 
 // -----
 
 func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
-                                            %init1: tensor<1x?x3x?x5x6xf32>,
-                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+                                            %init1: tensor<2x?x3x?x5x6xf32>, // non-unit leading dim keeps DropUnitDimOfBroadcastOp from firing
+                                            %init2: tensor<2x3x?x6x5x?xf32>) -> tensor<2x3x?x6x5x?xf32> {
   // CHECK-LABEL: @broadcast_transpose_fold_dynamic
   //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
-  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
-  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x?x3x?x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x?x6x5x?xf32>
   //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
   //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
   //       CHECK:   %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
   //       CHECK:   %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
   //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
   //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
-  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
-  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<2x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  //       CHECK:   return %[[BROADCAST]] : tensor<2x3x?x6x5x?xf32>
   %broadcast = linalg.broadcast
       ins(%input : tensor<?x?x5xf32>)
-      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      outs(%init1 : tensor<2x?x3x?x5x6xf32>)
       dimensions = [0, 2, 5]
   %transpose = linalg.transpose
-      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
-      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      ins(%broadcast : tensor<2x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<2x3x?x6x5x?xf32>)
       permutation = [0, 2, 3, 5, 4, 1]
-  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+  func.return %transpose : tensor<2x3x?x6x5x?xf32>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/106533


More information about the Mlir-commits mailing list