[Mlir-commits] [mlir] [mlir][linalg] Add a pattern to drop unit dim of `linalg.broadcast` (PR #106533)
llvmlistbot at llvm.org
Thu Aug 29 04:45:40 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Longsheng Mou (CoTinker)
Changes:
This PR adds a pattern to drop unit dims of `linalg.broadcast`. If a broadcasted dimension of the init has size 1, it can be dropped. For example:
```
%0 = linalg.broadcast ins(%input : tensor<4xf32>)
outs(%init : tensor<1x4xf32>)
dimensions = [0]
```
is converted to:
```
%collapsed = tensor.collapse_shape %init [[0, 1]] :
tensor<1x4xf32> into tensor<4xf32>
%0 = linalg.broadcast ins(%input : tensor<4xf32>)
outs(%collapsed : tensor<4xf32>)
dimensions = []
%expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4] :
tensor<4xf32> into tensor<1x4xf32>
```
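A unit dim in the middle of the init shape is handled the same way, except the broadcast survives with remapped `dimensions`. A sketch distilled from the `broadcast_unit_shape_middle_fold` test added below:
```
%0 = linalg.broadcast ins(%input : tensor<2xf32>)
                      outs(%init : tensor<2x1x3xf32>)
                      dimensions = [1, 2]
```
canonicalizes to roughly:
```
%collapsed = tensor.collapse_shape %init [[0], [1, 2]] :
    tensor<2x1x3xf32> into tensor<2x3xf32>
%0 = linalg.broadcast ins(%input : tensor<2xf32>)
                      outs(%collapsed : tensor<2x3xf32>)
                      dimensions = [1]
%expanded = tensor.expand_shape %0 [[0], [1, 2]] output_shape [2, 1, 3] :
    tensor<2x3xf32> into tensor<2x1x3xf32>
```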
---
Full diff: https://github.com/llvm/llvm-project/pull/106533.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+77-1)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+65-20)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..da45abc682f129 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2134,9 +2134,85 @@ void BroadcastOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
+/// If a broadcasted dimension of the init has size 1, it can be dropped.
+/// For example:
+/// ```
+/// %0 = linalg.broadcast ins(%input : tensor<4xf32>)
+/// outs(%init : tensor<1x4xf32>)
+/// dimensions = [0]
+/// ```
+/// is converted to:
+/// ```
+/// %collapsed = tensor.collapse_shape %init [[0, 1]] :
+/// tensor<1x4xf32> into tensor<4xf32>
+/// %0 = linalg.broadcast ins(%input : tensor<4xf32>)
+/// outs(%collapsed : tensor<4xf32>)
+/// dimensions = []
+/// %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4] :
+/// tensor<4xf32> into tensor<1x4xf32>
+/// ```
+struct DropUnitDimOfBroadcastOp : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (!broadcastOp.hasPureTensorSemantics())
+      return failure();
+
+    auto init = broadcastOp.getInit();
+    auto initType = dyn_cast<RankedTensorType>(init.getType());
+    if (!initType)
+      return failure();
+    auto initShape = initType.getShape();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    if (!llvm::any_of(dimensions,
+                      [&](auto dim) { return initShape[dim] == 1; }))
+      return failure();
+
+    SmallVector<int64_t> newDimensions;
+    int64_t dropDim = 0;
+    // Remap the broadcast dimensions to account for the dropped unit dims.
+    for (int64_t dim : dimensions) {
+      if (initShape[dim] != 1) {
+        newDimensions.push_back(dim - dropDim);
+      } else {
+        ++dropDim;
+      }
+    }
+    SmallVector<ReassociationIndices> reassociation;
+    // Build reassociation indices, grouping each dropped unit dimension with
+    // the dimension(s) that follow it.
+    bool needCollapse = false;
+    for (int64_t dim = 0; dim < initType.getRank(); ++dim) {
+      if (needCollapse) {
+        reassociation.back().push_back(dim);
+      } else {
+        reassociation.push_back({dim});
+      }
+      // If this is a dropped unit dim, fold the next dimension into the same
+      // group.
+      needCollapse =
+          (initShape[dim] == 1 && llvm::is_contained(dimensions, dim));
+    }
+
+    Location loc = broadcastOp.getLoc();
+    auto collapsedType =
+        tensor::CollapseShapeOp::inferCollapsedType(initType, reassociation);
+    auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
+        loc, collapsedType, init, reassociation);
+    auto newBroadcast =
+        rewriter
+            .create<linalg::BroadcastOp>(loc, broadcastOp.getInput(),
+                                         collapsedInit, newDimensions)
+            .getResult()[0];
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        broadcastOp, initType, newBroadcast, reassociation);
+    return success();
+  }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<DropUnitDimOfBroadcastOp, EraseIdentityLinalgOp<BroadcastOp>>(
+      context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4bc2ed140da91a..19900203f0f5ff 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1039,6 +1039,51 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// CHECK-LABEL: func.func @broadcast_unit_shape_front_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+// CHECK: %[[VAL_2:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [1, 2, 3] : tensor<2x3xf32> into tensor<1x2x3xf32>
+// CHECK: return %[[VAL_2]] : tensor<1x2x3xf32>
+// CHECK: }
+func.func @broadcast_unit_shape_front_fold(%input: tensor<2x3xf32>, %init: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+ %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<1x2x3xf32>) dimensions = [0]
+ return %0 : tensor<1x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_unit_shape_middle_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x1x3xf32>) -> tensor<2x1x3xf32> {
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1, 2]] : tensor<2x1x3xf32> into tensor<2x3xf32>
+// CHECK: %[[VAL_3:.*]] = linalg.broadcast ins(%[[VAL_0]] : tensor<2xf32>) outs(%[[VAL_2]] : tensor<2x3xf32>) dimensions = [1]
+// CHECK: %[[VAL_4:.*]] = tensor.expand_shape %[[VAL_3]] {{\[\[}}0], [1, 2]] output_shape [2, 1, 3] : tensor<2x3xf32> into tensor<2x1x3xf32>
+// CHECK: return %[[VAL_4]] : tensor<2x1x3xf32>
+// CHECK: }
+func.func @broadcast_unit_shape_middle_fold(%input: tensor<2xf32>, %init: tensor<2x1x3xf32>) -> tensor<2x1x3xf32> {
+ %0 = linalg.broadcast ins(%input: tensor<2xf32>) outs(%init: tensor<2x1x3xf32>) dimensions = [1, 2]
+ return %0 : tensor<2x1x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_unit_shape_dynamic_fold(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x1x3x?xf32>) -> tensor<2x1x3x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3]] : tensor<2x1x3x?xf32> into tensor<2x3x?xf32>
+// CHECK: %[[VAL_4:.*]] = linalg.broadcast ins(%[[VAL_0]] : tensor<2x?xf32>) outs(%[[VAL_3]] : tensor<2x3x?xf32>) dimensions = [1]
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_2]] : tensor<2x1x3x?xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2], [3]] output_shape [2, 1, 3, %[[VAL_5]]] : tensor<2x3x?xf32> into tensor<2x1x3x?xf32>
+// CHECK: return %[[VAL_6]] : tensor<2x1x3x?xf32>
+// CHECK: }
+func.func @broadcast_unit_shape_dynamic_fold(%input: tensor<2x?xf32>, %init: tensor<2x1x3x?xf32>) -> tensor<2x1x3x?xf32> {
+ %0 = linalg.broadcast ins(%input: tensor<2x?xf32>) outs(%init: tensor<2x1x3x?xf32>) dimensions = [1, 2]
+ return %0 : tensor<2x1x3x?xf32>
+}
+
+// -----
+
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
%transpose = linalg.transpose
@@ -1119,53 +1164,53 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
// -----
func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
- %init1: tensor<1x2x3x4x5x6xf32>,
- %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+ %init1: tensor<7x2x3x4x5x6xf32>,
+ %init2: tensor<7x6x2x3x5x4xf32>) -> tensor<7x6x2x3x5x4xf32> {
// CHECK-LABEL: @broadcast_transpose_fold
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
- // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
- // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<7x2x3x4x5x6xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<7x6x2x3x5x4xf32>
// CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
- // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
- // CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<7x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+ // CHECK: return %[[BROADCAST]] : tensor<7x6x2x3x5x4xf32>
%broadcast = linalg.broadcast
ins(%input : tensor<2x4x5xf32>)
- outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+ outs(%init1 : tensor<7x2x3x4x5x6xf32>)
dimensions = [0, 2, 5]
%transpose = linalg.transpose
- ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
- outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+ ins(%broadcast : tensor<7x2x3x4x5x6xf32>)
+ outs(%init2 : tensor<7x6x2x3x5x4xf32>)
permutation = [0, 5, 1, 2, 4, 3]
- func.return %transpose : tensor<1x6x2x3x5x4xf32>
+ func.return %transpose : tensor<7x6x2x3x5x4xf32>
}
// -----
func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
- %init1: tensor<1x?x3x?x5x6xf32>,
- %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+ %init1: tensor<2x?x3x?x5x6xf32>,
+ %init2: tensor<2x3x?x6x5x?xf32>) -> tensor<2x3x?x6x5x?xf32> {
// CHECK-LABEL: @broadcast_transpose_fold_dynamic
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
- // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
- // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x?x3x?x5x6xf32>
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x?x6x5x?xf32>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
// CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
// CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
- // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
- // CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<2x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+ // CHECK: return %[[BROADCAST]] : tensor<2x3x?x6x5x?xf32>
%broadcast = linalg.broadcast
ins(%input : tensor<?x?x5xf32>)
- outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+ outs(%init1 : tensor<2x?x3x?x5x6xf32>)
dimensions = [0, 2, 5]
%transpose = linalg.transpose
- ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
- outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+ ins(%broadcast : tensor<2x?x3x?x5x6xf32>)
+ outs(%init2 : tensor<2x3x?x6x5x?xf32>)
permutation = [0, 2, 3, 5, 4, 1]
- func.return %transpose : tensor<1x3x?x6x5x?xf32>
+ func.return %transpose : tensor<2x3x?x6x5x?xf32>
}
// -----
``````````
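One detail worth noting from the `broadcast_unit_shape_dynamic_fold` test above: dynamic dimensions are preserved across the rewrite, since the `tensor.expand_shape` recovers the dynamic extent from the original init via `tensor.dim`. A condensed view of the expected output (SSA names shortened here for readability):
```
%c3 = arith.constant 3 : index
%collapsed = tensor.collapse_shape %init [[0], [1, 2], [3]] :
    tensor<2x1x3x?xf32> into tensor<2x3x?xf32>
%bcast = linalg.broadcast ins(%input : tensor<2x?xf32>)
                          outs(%collapsed : tensor<2x3x?xf32>)
                          dimensions = [1]
%dim = tensor.dim %init, %c3 : tensor<2x1x3x?xf32>
%expanded = tensor.expand_shape %bcast [[0], [1, 2], [3]] output_shape [2, 1, 3, %dim] :
    tensor<2x3x?xf32> into tensor<2x1x3x?xf32>
```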
https://github.com/llvm/llvm-project/pull/106533
More information about the Mlir-commits mailing list