[Mlir-commits] [mlir] [MLIR] Folding unpack and pack sequence in data layout propagation (PR #138332)

Zhuoran Yin llvmlistbot at llvm.org
Fri May 2 13:03:28 PDT 2025


https://github.com/jerryyin created https://github.com/llvm/llvm-project/pull/138332

In the `DataLayoutPropagation` patterns, propagation can produce a sequence of an unpack op immediately followed by a pack op. Such sequences tend to disrupt tiling and can be optimized away. If the init tensor of the generic op is guaranteed to have no use in its payload, we can fold the unpack/pack pair. In particular:
 - The `BubbleUpPackOpThroughGenericOp` pattern bubbles the pack op up from after the generic op to before it.
 - The `PushDownUnPackOpThroughGenericOp` pattern pushes the unpack op down from before the generic op to after it.

In both patterns, if an operand of the generic op happens to come from an unpack op, there is no need to create a new pack of that operand. We can fold the unpack -> pack sequence and use the original source of the unpack op as the operand instead.
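
For illustration, here is the shape of the rewrite on the new test case added in this patch (`fold_unpack_pack_after_bubble_up`). The "before" IR is taken from the test; the "after" IR is a sketch of the expected result, where `#map4d` and the SSA names are illustrative:

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
#map4d = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Before: the generic consumes the result of an unpack, and its result is
// packed right back with the same inner_dims_pos / inner_tiles.
%empty = tensor.empty() : tensor<32x64xf32>
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
%1 = linalg.generic {indexing_maps = [#map, #map],
                     iterator_types = ["parallel", "parallel"]}
    ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
^bb0(%in: f32, %out: f32):
  %2 = arith.addf %in, %in : f32
  linalg.yield %2 : f32
} -> tensor<32x64xf32>
%empty1 = tensor.empty() : tensor<8x8x4x8xf32>
%pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>

// After: the unpack -> pack pair is folded away, so the generic operates
// directly on the packed source with a fresh init of the packed type.
%init = tensor.empty() : tensor<8x8x4x8xf32>
%packed = linalg.generic {indexing_maps = [#map4d, #map4d],
                          iterator_types = ["parallel", "parallel",
                                            "parallel", "parallel"]}
    ins(%arg0 : tensor<8x8x4x8xf32>) outs(%init : tensor<8x8x4x8xf32>) {
^bb0(%in: f32, %out: f32):
  %2 = arith.addf %in, %in : f32
  linalg.yield %2 : f32
} -> tensor<8x8x4x8xf32>
```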

From 92a01d3ca3b2d6589322910fb40a934b5d187c93 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 2 May 2025 19:46:53 +0000
Subject: [PATCH] Folding unpack and pack sequence

---
 .../Transforms/DataLayoutPropagation.cpp      | 36 ++++++++++
 .../Linalg/data-layout-propagation.mlir       | 68 ++++++++----------
 2 files changed, 69 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  int numDpsOuts = genericOp.getNumDpsInits();
+  Block *block = genericOp.getBody();
+  int numBlockArgs = block->getNumArguments();
+  for (int i = 0; i < numDpsOuts; ++i) {
+    if (!block->getArgument(numBlockArgs - numDpsOuts + i).use_empty())
+      return false;
+  }
+  return true;
+}
+
 /// Pack a genericOp and return it.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
                                const PackInfo &packInfo) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+
+  // Note: folding the unpack/pack pair also requires that the generic
+  // body has no gather semantics. Such cases are already rejected by
+  // both BubbleUpPackOpThroughGenericOp and
+  // PushDownUnPackOpThroughGenericOp, so here it is safe to fold as
+  // long as the init operands are unused.
+  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+
+    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
+
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the unpack/pack pair can be folded:
+  // 1) use the unpack op's source as the operand, folding away the
+  //    unpack -> pack sequence, and
+  // 2) replace the generic op's init tensor with the pack op's dest.
+  if (canUnpackPackFold) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..dddcba661bf56 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[ARG0_PACK]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
-// CHECK-SAME:      outs(%[[DEST]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK:         %[[RES:.+]] = linalg.generic
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      ins(%[[ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK:         %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME:      into %[[PACK_EMPTY]]
 // CHECK:         %[[POOL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<32x64xf32>
+  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+  return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK:         return %[[GENERIC]] : tensor<8x8x4x8xf32>


