[Mlir-commits] [mlir] b9d6cbd - [MLIR] Folding unpack and pack sequence in data layout propagation from padded domain (#138332)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 7 08:14:46 PDT 2025
Author: Zhuoran Yin
Date: 2025-05-07T11:14:42-04:00
New Revision: b9d6cbd4dc1def3f15b7d5ebb8cb4714bdad22bf
URL: https://github.com/llvm/llvm-project/commit/b9d6cbd4dc1def3f15b7d5ebb8cb4714bdad22bf
DIFF: https://github.com/llvm/llvm-project/commit/b9d6cbd4dc1def3f15b7d5ebb8cb4714bdad22bf.diff
LOG: [MLIR] Folding unpack and pack sequence in data layout propagation from padded domain (#138332)
In `DataLayoutPropagation` patterns, propagation can produce a sequence of an
unpack op followed by a pack op. Such sequences tend to disrupt tiling and can
be optimized away. This is especially true for pack and unpack ops on padded
values.
The idea of this patch is to optimize the propagation by never creating
the unpack + pack pair in cases where the padding value does not matter for
the op that is being propagated through. In particular, we can optimize the
unpack/pack pair away in the `PushDownUnPackOpThroughGenericOp` pattern.
If an operand of the generic op happens to come from an unpack, there's
no need to create a new pack of that operand. We can fold the
unpack -> pack sequence and use the original source of the unpack op as
the operand.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..26904f1f40d12 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,42 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}
-/// Pack a genericOp and return it.
+/// This function is a helper subroutine to pack a genericOp and return it. It
+/// will create a new generic op with the packed operand and the packed output
+/// according to packInfo when we attempt to push down unpack or bubble up pack
+/// around it. Implicitly this will only work when a packInfo can be obtained.
+/// This make sure that we are only using this function on parallel permuted
+/// dimensions.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
- const PackInfo &packInfo) {
+ const PackInfo &packInfo,
+ bool isFoldableUnpackPack) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
+ SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
+ if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+ inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+ } else {
+ inputOperandsFromUnpackedSource.push_back(packedOperand);
+ }
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
+ // If the pack and unpack op can be folded:
+ // 1) use unpack op source op for operand to fold unpack -> pack sequence.
+ // 2) init tensor of the generic op can be replaced by the destination of the
+ // pack op.
+ if (isFoldableUnpackPack) {
+ inputOperands = inputOperandsFromUnpackedSource;
+ if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+ dest = destPack.getDest();
+ }
+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
@@ -447,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
.getDefiningOp<tensor::EmptyOp>()) {
dest = packOpDest;
}
+ // pack(unpack) isn't naively foldable because the unpack op can be from
+ // an arbitrary domain so we need to keep both.
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
- *packInfo);
+ *packInfo, /*isFoldableUnpackPack=*/false);
}
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1085,8 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
}
// Pack the genericOp.
+ // pack(unpack) is foldable in this case. This is because in pushing down the
+ // unpack, by default we will populate an additional pack op after the unpack.
+ // This guarantees them to be foldable.
GenericOp newGenericOp =
- packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
+ packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
+ /*isFoldableUnpackPack=*/true);
Value newResult =
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..63f068d3f8681 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
-// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-LABEL: func.func @unpack_on_input
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -524,22 +510,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +539,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +777,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
}
// CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
@@ -943,14 +907,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
// CHECK: %[[POOL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1381,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x64xbf16>
+ return %0 : tensor<?x64xbf16>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG2]]
+// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = arith.addf %in, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x64xf32>
+ return %0 : tensor<?x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
+// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG1]]
+// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
More information about the Mlir-commits
mailing list