[Mlir-commits] [mlir] [mlir][linalg] Add pattern to propagate pack up through linalg.fill (PR #92097)
llvmlistbot at llvm.org
Tue May 14 04:16:08 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Adam Siemieniuk (adam-smnk)
Changes:
The pattern makes it possible to avoid packing the result of a fill op and instead fill directly into the packed output.
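For illustration, a minimal before/after sketch mirroring the new test cases below (`%cst` stands in for the fill value):

```mlir
// Before: fill an empty tensor, then pack the filled result.
%empty = tensor.empty() : tensor<4x56x56x64xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
%dest = tensor.empty() : tensor<4x2x56x56x32xf32>
%pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>

// After: fill directly into the packed destination; the pack is elided.
%dest = tensor.empty() : tensor<4x2x56x56x32xf32>
%packed_fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<4x2x56x56x32xf32>) -> tensor<4x2x56x56x32xf32>
```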
---
Full diff: https://github.com/llvm/llvm-project/pull/92097.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+42-5)
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+54)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2bea083ac2d78..810ae8d1b0fe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -730,6 +730,43 @@ class BubbleUpPackOpThroughReshapeOp final
ControlPropagationFn controlFn;
};
+/// Propagate a tensor.pack operation up through a linalg.fill. The idea is to
+/// avoid packing the result of a fill op by creating a 'packed' fill instead.
+class BubbleUpPackOpThroughFillOp final
+ : public OpRewritePattern<tensor::PackOp> {
+public:
+ BubbleUpPackOpThroughFillOp(MLIRContext *context, ControlPropagationFn fun)
+ : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ Value source = packOp.getSource();
+ auto fillOp = source.getDefiningOp<linalg::FillOp>();
+ if (!fillOp)
+ return failure();
+
+ // User-controlled propagation function.
+ if (!controlFn(fillOp))
+ return failure();
+
+ if (!fillOp.getResult(0).hasOneUse())
+ return failure();
+
+ // Fill destination must be an empty tensor.
+ // Otherwise, packing cannot be removed.
+ if (!fillOp.getOutputs()[0].getDefiningOp<tensor::EmptyOp>())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(packOp, fillOp.getInputs(),
+ packOp.getDest());
+
+ return success();
+ }
+
+private:
+ ControlPropagationFn controlFn;
+};
+
/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
@@ -1074,9 +1111,9 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
- patterns
- .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
- BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
- PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
- patterns.getContext(), controlPackUnPackPropagation);
+ patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+ BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
+ BubbleUpPackOpThroughFillOp, PushDownUnPackOpThroughGenericOp,
+ PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index bee08503298fd..41c012fa29690 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1071,3 +1071,57 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_fill(%arg0: f32) -> tensor<4x2x56x56x32xf32> {
+ %empty_fill = tensor.empty() : tensor<4x56x56x64xf32>
+ %fill = linalg.fill ins(%arg0 : f32) outs(%empty_fill : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+ %empty_pack = tensor.empty() : tensor<4x2x56x56x32xf32>
+ %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+ return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_pack_through_fill(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[EMPTY_PACKED:.+]] = tensor.empty() : tensor<4x2x56x56x32xf32>
+// CHECK: %[[FILL_PACKED:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[ARG0]] : f32)
+// CHECK-SAME: outs(%[[EMPTY_PACKED]] : tensor<4x2x56x56x32xf32>)
+// CHECK: return %[[FILL_PACKED]]
+
+// -----
+
+func.func @bubble_up_pack_into_arg_through_fill(%arg0: f32, %arg1: tensor<4x2x56x56x32xf32>) -> tensor<4x2x56x56x32xf32> {
+ %empty_fill = tensor.empty() : tensor<4x56x56x64xf32>
+ %fill = linalg.fill ins(%arg0 : f32) outs(%empty_fill : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+ %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %arg1 : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+ return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_pack_into_arg_through_fill(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: tensor<4x2x56x56x32xf32>
+// CHECK: %[[FILL_PACKED_ARG:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[ARG0]] : f32)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<4x2x56x56x32xf32>)
+// CHECK: return %[[FILL_PACKED_ARG]]
+
+// -----
+
+func.func @no_bubble_up_pack_through_fill_into_arg(%arg0: f32, %arg1: tensor<4x56x56x64xf32>) -> tensor<4x2x56x56x32xf32> {
+ %fill = linalg.fill ins(%arg0 : f32) outs(%arg1 : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+ %empty_pack = tensor.empty() : tensor<4x2x56x56x32xf32>
+ %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+ return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_fill_into_arg(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-SAME: %[[ARG1:.+]]: tensor<4x56x56x64xf32>
+// CHECK: %[[FILL_ARG:.+]] = linalg.fill
+// CHECK-SAME: ins(%[[ARG0]] : f32)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<4x56x56x64xf32>)
+// CHECK: %[[EMPTY_PACK:.+]] = tensor.empty() : tensor<4x2x56x56x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[FILL_ARG]]{{.*}}into %[[EMPTY_PACK]]
+// CHECK: return %[[PACK]]
``````````
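For context, the patterns populated by `populateDataLayoutPropagationPatterns` are typically applied via a greedy rewrite. A minimal sketch follows; the driver function, the always-allow control callback, and the assumption that the pass runs on an existing `func::FuncOp` are illustrative assumptions, not part of this PR:

```cpp
// Sketch: wire the data layout propagation patterns into a greedy rewrite.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult propagatePackUnpack(mlir::func::FuncOp funcOp) {
  mlir::RewritePatternSet patterns(funcOp.getContext());
  // Always allow propagation; a real pass would plug in its own policy here.
  mlir::linalg::populateDataLayoutPropagationPatterns(
      patterns, [](mlir::Operation *) { return true; });
  return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```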
https://github.com/llvm/llvm-project/pull/92097