[Mlir-commits] [mlir] [mlir][Linalg] Allow propagation of pack through multi use pad (PR #98039)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 8 09:08:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
This allows bubbling `tensor.pack` through `tensor.pad` when the pad has multiple uses. A new pad is created on the packed source, and a `tensor.unpack` of that new pad is inserted to replace the original pad's remaining (non-pack) uses.
To keep the previous behavior, the layout propagation control function can be modified to reject propagation through pads that have multiple uses.
---
Full diff: https://github.com/llvm/llvm-project/pull/98039.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+19-8)
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+48-25)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..5f7cf30335e99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
if (!controlFn(padOp))
return failure();
- if (!padOp.getResult().hasOneUse())
- return failure();
-
// TODO: Enable padding when the padding values are the same.
if (packOp.getPaddingValue())
return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
// Bail out if one of the padded dimension is a tiled one.
llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(padOp);
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
auto empty = tensor::PackOp::createDestinationTensor(
- rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+ rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
outerDimsPerm);
- Value packedSource = rewriter.create<tensor::PackOp>(
- loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+ auto sourcePack = rewriter.create<tensor::PackOp>(
+ loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
/*padding=*/std::nullopt, outerDimsPerm);
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+ loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
padOp.getNofold());
+
+ // If the pad has more than one user, create an unpack on the new pad to
+ // replace the other uses.
+ if (!padOp->hasOneUse()) {
+ auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+ rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+ Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+ loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+ rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+ }
+
+ // Replace the pack with the new pad.
rewriter.replaceOp(packOp, newPadOp.getResult());
+
return success();
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..d9206432379fb 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0_PACK]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
// -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
// -----
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+ %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+ %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+ return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK: return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
%init = tensor.empty() : tensor<128x256xi32>
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ALLOC]]
// -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
// CHECK-LABEL: func.func @unpack_empty_inner_dims
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
%arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
%elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
// -----
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
%arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
{
%reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
%filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
%init = tensor.empty() : tensor<16x540x960xi32>
%empty = tensor.empty() : tensor<1x16x1080x1920xi32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/98039
More information about the Mlir-commits
mailing list