[Mlir-commits] [mlir] [MLIR] Folding unpack and pack sequence in data layout propagation (PR #138332)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 2 13:04:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Zhuoran Yin (jerryyin)

<details>
<summary>Changes</summary>

The `DataLayoutPropagation` patterns can produce a sequence of an unpack op followed by a pack op. Such sequences tend to disrupt tiling and can be optimized. If the generic op's payload is guaranteed not to use its init tensor, the unpack/pack pair can be optimized away. In particular:
 - The `BubbleUpPackOpThroughGenericOp` pattern bubbles the pack op up from after the generic op to before it.
 - The `PushDownUnPackOpThroughGenericOp` pattern pushes the unpack op down from before the generic op to after it.

In both patterns, if an operand of the generic op happens to come from an unpack, there is no need to create a new pack for that operand: the unpack -> pack sequence can be folded, and the original source of the unpack op can be used directly as the operand.

---
Full diff: https://github.com/llvm/llvm-project/pull/138332.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+36) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+34-36) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,69 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+/// Returns true if none of the init (`outs`) block arguments of `genericOp`
+/// is used in the payload, i.e. the initial output values are never read.
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &init) {
+    return genericOp.getMatchingBlockArgument(&init).use_empty();
+  });
+}
+
+/// Returns true if `packOp` repacks the result of `unPackOp` into exactly the
+/// unpack's source layout (same type, inner dims, tiles and outer permutation,
+/// no padding), so that the unpack -> pack pair is a round trip.
+static bool isInverseOfUnPack(linalg::PackOp packOp,
+                              linalg::UnPackOp unPackOp) {
+  return packOp.getSource() == unPackOp.getResult() &&
+         !packOp.getPaddingValue() &&
+         packOp.getDest().getType() == unPackOp.getSource().getType() &&
+         packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
+         packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
+         packOp.getMixedTiles() == unPackOp.getMixedTiles();
+}
+
 /// Pack a genericOp and return it.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
                                const PackInfo &packInfo) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+
+  // Note: folding unpack -> pack additionally requires that the generic body
+  // has no gather semantics. BubbleUpPackOpThroughGenericOp and
+  // PushDownUnPackOpThroughGenericOp already reject such cases, so it is
+  // sufficient here to require that the init values are never read.
+  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+
+    // Use the unpack source only when a pack was actually created for this
+    // operand and it exactly inverts the producing unpack; otherwise the
+    // unpack source would have the wrong type or layout for this operand.
+    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
+    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
+    if (unpackOp && packOp && isInverseOfUnPack(packOp, unpackOp))
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    else
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the pack and unpack ops can be folded:
+  // 1) use the unpack source as the operand, folding unpack -> pack;
+  // 2) the init values are never read, so if `dest` is produced by a pack,
+  //    use that pack's destination (typically a tensor.empty) instead.
+  if (canUnpackPackFold) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..dddcba661bf56 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[ARG0_PACK]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
-// CHECK-SAME:      outs(%[[DEST]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK:         %[[RES:.+]] = linalg.generic
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      ins(%[[ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK:         %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME:      into %[[PACK_EMPTY]]
 // CHECK:         %[[POOL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<32x64xf32>
+  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+  return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK:         return %[[GENERIC]] : tensor<8x8x4x8xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/138332


More information about the Mlir-commits mailing list