[Mlir-commits] [mlir] [MLIR] Folding unpack and pack sequence in data layout propagation (PR #138332)
Zhuoran Yin
llvmlistbot at llvm.org
Fri May 2 13:05:10 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/138332
From 67da59818a7bcce97898b3f6aadff11262c65b95 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 2 May 2025 19:46:53 +0000
Subject: [PATCH] Folding unpack and pack sequence
---
.../Transforms/DataLayoutPropagation.cpp | 36 ++++++++++
.../Linalg/data-layout-propagation.mlir | 68 +++++++++----------
2 files changed, 69 insertions(+), 35 deletions(-)
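This patch teaches the data layout propagation patterns to fold a
linalg.unpack -> linalg.pack round trip when packing a generic op,
provided the generic's init block arguments are unused in its body
(generics with gather semantics are already rejected by these patterns).
With placeholder names adapted from the new test case at the end of the
patch, IR of the form

  %u = linalg.unpack %arg0 ... into %empty
  %r = linalg.generic ... ins(%u : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>)
  %p = linalg.pack %r ... into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>

now propagates to a single packed generic:

  %r = linalg.generic ... ins(%arg0 : tensor<8x8x4x8xf32>) outs(%fresh_empty : tensor<8x8x4x8xf32>)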
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}
+/// Returns true when none of the init (outs) block arguments of \p genericOp
+/// are used in its body, i.e. the outputs are write-only.
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+    return genericOp.getMatchingBlockArgument(&operand).use_empty();
+  });
+}
+
/// Pack a genericOp and return it.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
+ SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;
+
+  // Note: for the unpack -> pack pair to be foldable, the body of the
+  // generic op must also be free of gather semantics. Since such
+  // scenarios have already been rejected by both
+  // BubbleUpPackOpThroughGenericOp and PushDownUnPackOpThroughGenericOp,
+  // it is safe to fold whenever the init operands are unused.
+ bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
+
+ if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+ inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+ } else {
+ inputOperandsFromUnpackedSource.push_back(packedOperand);
+ }
+
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
+  // If the unpack -> pack sequence can be folded:
+  // 1) use the source of the unpack op as the input operand, folding away
+  //    the unpack -> pack pair, and
+  // 2) replace the init tensor of the generic op with the destination of
+  //    the pack op (a fresh tensor.empty).
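+  //
+  // Illustrative fold (placeholder names, assuming the unpack and pack
+  // metadata match):
+  //   %u = linalg.unpack %src ... into %e
+  //   %r = linalg.generic ... ins(%u) outs(%pack_dest)
+  // is rewritten so the packed generic reads %src directly and uses the
+  // pack op's destination as its init.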
+ if (canUnpackPackFold) {
+ inputOperands = inputOperandsFromUnpackedSource;
+ if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+ dest = destPack.getDest();
+ }
+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..fde1c40fb3c12 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
}
// CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
// CHECK: %[[POOL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %2 = arith.addf %in, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<32x64xf32>
+ %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+ %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+ return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>
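The new test is exercised by the file's existing RUN line; assuming the
driver flag is unchanged, it can be checked locally with:

  mlir-opt -split-input-file --test-linalg-data-layout-propagation mlir/test/Dialect/Linalg/data-layout-propagation.mlir | FileCheck mlir/test/Dialect/Linalg/data-layout-propagation.mlir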