[Mlir-commits] [mlir] [mlir][linalg] Add pattern to propagate pack up through linalg.fill (PR #92097)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 14 04:16:08 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

This pattern avoids packing the result of a fill op by filling directly into the packed-shape output instead.

---
Full diff: https://github.com/llvm/llvm-project/pull/92097.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+42-5) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+54) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2bea083ac2d78..810ae8d1b0fe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -730,6 +730,43 @@ class BubbleUpPackOpThroughReshapeOp final
   ControlPropagationFn controlFn;
 };
 
+/// Propagate a tensor.pack operation up through a linalg.fill. The idea is to
+/// avoid packing a fill op and create a 'packed' fill instead.
+class BubbleUpPackOpThroughFillOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughFillOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    Value source = packOp.getSource();
+    auto fillOp = source.getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // User controlled propagation function.
+    if (!controlFn(fillOp))
+      return failure();
+
+    if (!fillOp.getResult(0).hasOneUse())
+      return failure();
+
+    // Fill destination must be an empty tensor.
+    // Otherwise, packing cannot be removed.
+    if (!fillOp.getOutputs()[0].getDefiningOp<tensor::EmptyOp>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(packOp, fillOp.getInputs(),
+                                                packOp.getDest());
+
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 /// Push down unpack op through expand shape op when the packed dims can be
 /// projected to the dims after expanding. This is possible when the inner tile
 /// sizes can divide the projected dims.
@@ -1074,9 +1111,9 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns
-      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
-              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
-          patterns.getContext(), controlPackUnPackPropagation);
+  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+                  BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
+                  BubbleUpPackOpThroughFillOp, PushDownUnPackOpThroughGenericOp,
+                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index bee08503298fd..41c012fa29690 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1071,3 +1071,57 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_fill(%arg0: f32) -> tensor<4x2x56x56x32xf32> {
+  %empty_fill = tensor.empty() : tensor<4x56x56x64xf32>
+  %fill = linalg.fill ins(%arg0 : f32) outs(%empty_fill : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+  %empty_pack = tensor.empty() : tensor<4x2x56x56x32xf32>
+  %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+  return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_pack_through_fill(
+// CHECK-SAME:     %[[ARG0:.+]]: f32
+// CHECK:         %[[EMPTY_PACKED:.+]] = tensor.empty() : tensor<4x2x56x56x32xf32>
+// CHECK:         %[[FILL_PACKED:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[ARG0]] : f32)
+// CHECK-SAME:      outs(%[[EMPTY_PACKED]] : tensor<4x2x56x56x32xf32>)
+// CHECK:         return %[[FILL_PACKED]]
+
+// -----
+
+func.func @bubble_up_pack_into_arg_through_fill(%arg0: f32, %arg1: tensor<4x2x56x56x32xf32>) -> tensor<4x2x56x56x32xf32> {
+  %empty_fill = tensor.empty() : tensor<4x56x56x64xf32>
+  %fill = linalg.fill ins(%arg0 : f32) outs(%empty_fill : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+  %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %arg1 : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+  return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_pack_into_arg_through_fill(
+// CHECK-SAME:     %[[ARG0:.+]]: f32
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<4x2x56x56x32xf32>
+// CHECK:         %[[FILL_PACKED_ARG:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[ARG0]] : f32)
+// CHECK-SAME:      outs(%[[ARG1]] : tensor<4x2x56x56x32xf32>)
+// CHECK:         return %[[FILL_PACKED_ARG]]
+
+// -----
+
+func.func @no_bubble_up_pack_through_fill_into_arg(%arg0: f32, %arg1: tensor<4x56x56x64xf32>) -> tensor<4x2x56x56x32xf32> {
+  %fill = linalg.fill ins(%arg0 : f32) outs(%arg1 : tensor<4x56x56x64xf32>) -> tensor<4x56x56x64xf32>
+  %empty_pack = tensor.empty() : tensor<4x2x56x56x32xf32>
+  %pack = tensor.pack %fill outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<4x56x56x64xf32> -> tensor<4x2x56x56x32xf32>
+  return %pack : tensor<4x2x56x56x32xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_fill_into_arg(
+// CHECK-SAME:     %[[ARG0:.+]]: f32
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<4x56x56x64xf32>
+// CHECK:         %[[FILL_ARG:.+]] = linalg.fill
+// CHECK-SAME:      ins(%[[ARG0]] : f32)
+// CHECK-SAME:      outs(%[[ARG1]] : tensor<4x56x56x64xf32>)
+// CHECK:         %[[EMPTY_PACK:.+]] = tensor.empty() : tensor<4x2x56x56x32xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[FILL_ARG]]{{.*}}into %[[EMPTY_PACK]]
+// CHECK:         return %[[PACK]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/92097


More information about the Mlir-commits mailing list