[Mlir-commits] [mlir] 7d040d4 - [mlir][linalg] Handle outer_dims_perm in linalg.pack consumer fusion. (#149426)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 18 09:42:43 PDT 2025
Author: Han-Chung Wang
Date: 2025-07-18T09:42:40-07:00
New Revision: 7d040d4675baf6881cf50c4dba78cc18af85f9ef
URL: https://github.com/llvm/llvm-project/commit/7d040d4675baf6881cf50c4dba78cc18af85f9ef
DIFF: https://github.com/llvm/llvm-project/commit/7d040d4675baf6881cf50c4dba78cc18af85f9ef.diff
LOG: [mlir][linalg] Handle outer_dims_perm in linalg.pack consumer fusion. (#149426)
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5a10883a6043c..b059bcc025315 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -893,6 +893,13 @@ struct PackOpTiling
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
+ SmallVector<int64_t> outerShapeWithoutTranspose(
+ packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
+ if (!packOp.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ outerShapeWithoutTranspose,
+ invertPermutationVector(packOp.getOuterDimsPerm()));
+ }
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
if (dimAndTileMapping.count(dim)) {
FailureOr<int64_t> cstTileSize =
@@ -908,7 +915,7 @@ struct PackOpTiling
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
// hard check to determine if a dimension is tiled or not.
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
- int64_t destDimSize = packOp.getDestType().getDimSize(dim);
+ int64_t destDimSize = outerShapeWithoutTranspose[dim];
bool isTiled = failed(cstTileSize) ||
ShapedType::isDynamic(srcDimSize) ||
cstTileSize.value() != srcDimSize;
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 7b0a8494a8acb..20164d5dfd91a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
// -----
+
+func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
+ %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
+ return %pack : tensor<2x64x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+
+// -----
+
// It is valid to fuse the pack op in perfect tiling scenario when the dimension
// is dynamic and padding is not needed.
More information about the Mlir-commits mailing list