[Mlir-commits] [mlir] [MLIR] Folding unpack and pack sequence in data layout propagation (PR #138332)

Zhuoran Yin llvmlistbot at llvm.org
Fri May 2 13:05:10 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/138332

>From 67da59818a7bcce97898b3f6aadff11262c65b95 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 2 May 2025 19:46:53 +0000
Subject: [PATCH] Folding unpack and pack sequence

---
 .../Transforms/DataLayoutPropagation.cpp      | 36 ++++++++++
 .../Linalg/data-layout-propagation.mlir       | 68 +++++++++----------
 2 files changed, 69 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+/// Returns true iff no init (outs) block argument of `genericOp` has a use.
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  Block *block = genericOp.getBody();
+  int numDpsOuts = genericOp.getNumDpsInits();
+  int firstInitArg = block->getNumArguments() - numDpsOuts; // Inits are last.
+  for (int i = 0; i < numDpsOuts; ++i)
+    if (!block->getArgument(firstInitArg + i).use_empty())
+      return false; // Every init is checked, not only the first one.
+  return true;
+}
+
 /// Pack a genericOp and return it.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
                                const PackInfo &packInfo) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+
+  // Note: canUnpackPackFold also needs to guarantee that the generic body
+  // doesn't have gather semantics. Since such scenarios have already been
+  // rejected by both BubbleUpPackOpThroughGenericOp and
+  // PushDownUnPackOpThroughGenericOp, we can safely assume
+  // canUnpackPackFold holds as long as the init is not used.
+  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+
+    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
+
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the pack and unpack ops can be folded:
+  // 1) use the unpack op's source as the operand, folding unpack -> pack;
+  // 2) the init tensor of the generic op can be replaced by the new
+  //    tensor.empty as the generic out.
+  if (canUnpackPackFold) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..fde1c40fb3c12 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[ARG0_PACK]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
-// CHECK-SAME:      outs(%[[DEST]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK:         %[[RES:.+]] = linalg.generic
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      ins(%[[ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK:         %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME:      into %[[PACK_EMPTY]]
 // CHECK:         %[[POOL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<32x64xf32>
+  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+  return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic 
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK:         return %[[GENERIC]] : tensor<8x8x4x8xf32>



More information about the Mlir-commits mailing list