[Mlir-commits] [mlir] [mlir][Linalg] Allow propagation of pack through multi use pad (PR #98039)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 8 09:08:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

This allows bubbling `tensor.pack` through `tensor.pad` when the pad has multiple uses. A new pad is created on the packed source, and a `tensor.unpack` of that new pad is inserted to replace the original pad's other uses.

To keep the previous behavior, the layout propagation control function can be modified to disallow multi-use propagation.

---
Full diff: https://github.com/llvm/llvm-project/pull/98039.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+19-8) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+48-25) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..5f7cf30335e99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     if (!controlFn(padOp))
       return failure();
 
-    if (!padOp.getResult().hasOneUse())
-      return failure();
-
     // TODO: Enable padding when the padding values are the same.
     if (packOp.getPaddingValue())
       return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
     // Bail out if one of the padded dimension is a tiled one.
     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(padOp);
 
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
     auto empty = tensor::PackOp::createDestinationTensor(
-        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
         outerDimsPerm);
-    Value packedSource = rewriter.create<tensor::PackOp>(
-        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+    auto sourcePack = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
         /*padding=*/std::nullopt, outerDimsPerm);
 
     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
 
     auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
         padOp.getNofold());
+
+    // If the pad has more than one user, create an unpack on the new pad to
+    // replace the other uses.
+    if (!padOp->hasOneUse()) {
+      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+    }
+
+    // Replace the pack with the new pad.
     rewriter.replaceOp(packOp, newPadOp.getResult());
+
     return success();
   }
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..d9206432379fb 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0_PACK]]
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK-SAME:      outs(%[[DEST]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
 // CHECK:         %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
 
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
 
 // -----
 
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK:         return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %init = tensor.empty() : tensor<128x256xi32>
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
 // CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:      into %[[ALLOC]]
 
 // -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
 // CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, 
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
       %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
 
 // -----
 
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, 
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
   %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
 {
   %reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, 
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
     %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
   %init = tensor.empty() : tensor<16x540x960xi32>
   %empty = tensor.empty() : tensor<1x16x1080x1920xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/98039


More information about the Mlir-commits mailing list