[Mlir-commits] [mlir] [mlir][linalg] Improve linalg.pack consumer fusion. (PR #148993)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Jul 17 12:40:27 PDT 2025
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/148993
>From 3642d259c3ece69cbc41ab74af863d6b4b221839 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 15 Jul 2025 16:50:31 -0700
Subject: [PATCH 1/5] [mlir][linalg] Improve linalg.pack consumer fusion.
If a dimension is not tiled, it is always valid to fuse the pack op,
even if it has padding semantics, because it always generates a full
slice along the dimension.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 41 +--
.../tile-and-fuse-consumer.mlir | 278 ++++++++++--------
2 files changed, 184 insertions(+), 135 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 513cecef29b61..fb9ba4ccb14af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -887,26 +887,13 @@ struct PackOpTiling
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
-
auto packOp = cast<PackOp>(op);
- // It is not trivial to infer dest tile from source tile if `packOp` has
- // padding semantic.
- if (packOp.getPaddingValue())
- return failure();
-
Location loc = packOp.getLoc();
-
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
if (dimAndTileMapping.count(dim)) {
- FailureOr<int64_t> cstSize =
- ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, sizes[dim],
- /*stopCondition=*/nullptr, /*closedUB=*/true);
- std::optional<int64_t> cstInnerSize =
- getConstantIntValue(dimAndTileMapping[dim]);
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -916,12 +903,25 @@ struct PackOpTiling
// (0,0)~(0,4) at first row.
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
// respectively inserted into two rows with different length, including
- // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
- // them, thus adding below constraint to bypass them temporarily. In
- // another word, we can only support tiling with consumer if the tile
- // size for the producer is a multiple of the inner tile size for the
- // packed dimensions at this moment.
- if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+ // first row: (0,5) and second row (1,0)~(1,3).
+ // It is hard to coordinate them, thus adding below constraint to bypass
+ // them temporarily. In another word, we can only support tiling with
+ // consumer if the tile size for the producer is either a multiple of
+ // the inner tile size for the packed dimensions or the dimension is not
+ // tiled at this moment.
+ FailureOr<int64_t> cstTileSize =
+ ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, sizes[dim],
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ std::optional<int64_t> cstInnerSize =
+ getConstantIntValue(dimAndTileMapping[dim]);
+ int64_t dimSize = packOp.getSourceType().getDimSize(dim);
+ // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard
+ // check to determine if a dimension is tiled or not.
+ bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) ||
+ cstTileSize.value() != dimSize;
+ if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
+ *cstTileSize % *cstInnerSize != 0)) {
return failure();
}
@@ -988,7 +988,8 @@ struct PackOpTiling
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(outSlice);
- assert(!packOp.getPaddingValue() && "Expect no padding semantic");
+ if (auto val = packOp.getPaddingValue())
+ tiledOperands.push_back(val);
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d09373bdb3f14..da3592547e125 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -193,33 +193,33 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
- %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
- %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
- tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
+ func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
}
- %1 = tensor.empty() : tensor<64x64xf32>
- %2 = tensor.empty() : tensor<64x64xf32>
- %3 = tensor.empty() : tensor<64x64xf32>
- %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
- %6 = arith.mulf %in, %in_0 : f32
- %7 = arith.subf %out, %6 : f32
- %8 = arith.addf %out_1, %in : f32
- linalg.yield %7, %8 : f32, f32
- } -> (tensor<64x64xf32>, tensor<64x64xf32>)
- %5 = tensor.empty() : tensor<2048xf32>
- %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
- return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
}
+ %1 = tensor.empty() : tensor<64x64xf32>
+ %2 = tensor.empty() : tensor<64x64xf32>
+ %3 = tensor.empty() : tensor<64x64xf32>
+ %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+ %6 = arith.mulf %in, %in_0 : f32
+ %7 = arith.subf %out, %6 : f32
+ %8 = arith.addf %out_1, %in : f32
+ linalg.yield %7, %8 : f32, f32
+ } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+ %5 = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+ return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -269,38 +269,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<2048xf32>
- %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
- return %unpack : tensor<2048xf32>
+ func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<2048xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+ return %unpack : tensor<2048xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
}
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
@@ -332,38 +332,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<2047xf32>
- %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
- return %unpack : tensor<2047xf32>
+ func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<2047xf32>
+ %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+ return %unpack : tensor<2047xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
}
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
@@ -395,38 +395,38 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
- %c4 = arith.constant 4 : index
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
- ^bb0(%in: f32, %in_16: f32, %out: f32):
- %13 = arith.mulf %in, %in_16 : f32
- %14 = arith.addf %out, %13 : f32
- linalg.yield %14 : f32
- } -> tensor<32x32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
- }
- }
- %output = tensor.empty() : tensor<4x32x16xf32>
- %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
- return %pack : tensor<4x32x16xf32>
+ func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
}
+ %output = tensor.empty() : tensor<4x32x16xf32>
+ %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+ return %pack : tensor<4x32x16xf32>
+ }
}
-
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
- %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %loop = transform.structured.match ops{["scf.forall"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
- : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
@@ -451,6 +451,54 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+ %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+ %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<23x32x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+ return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+
+// -----
+
module {
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
%c0 = arith.constant 0 : index
@@ -489,7 +537,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -645,7 +693,7 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
- }
+ }
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
%result = linalg.generic {
@@ -719,7 +767,7 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
- }
+ }
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
%result = linalg.generic {
>From 061d4a2336d958d3bb83fd0ea64d7ce20b9cbc61 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 11:15:31 -0700
Subject: [PATCH 2/5] Restrict the fusion condition.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 27 ++++++---
.../tile-and-fuse-consumer.mlir | 59 ++++++++++++++-----
2 files changed, 63 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fb9ba4ccb14af..f609c65818e43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
@@ -915,13 +916,25 @@ struct PackOpTiling
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize =
getConstantIntValue(dimAndTileMapping[dim]);
- int64_t dimSize = packOp.getSourceType().getDimSize(dim);
- // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard
- // check to determine if a dimension is tiled or not.
- bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) ||
- cstTileSize.value() != dimSize;
- if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
- *cstTileSize % *cstInnerSize != 0)) {
+ // If a dimension is not tiled, it is always valid to fuse the pack op,
+ // even if the op has padding semantics. Because it always generates a
+ // full slice along the dimension.
+ // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
+ // hard check to determine if a dimension is tiled or not.
+ int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
+ bool isTiled = failed(cstTileSize) ||
+ ShapedType::isDynamic(srcDimSize) ||
+ cstTileSize.value() != srcDimSize;
+ int64_t destDimSize = packOp.getDestType().getDimSize(dim);
+ bool needPadding = ShapedType::isDynamic(destDimSize) ||
+ !cstInnerSize ||
+ destDimSize * cstInnerSize.value() != srcDimSize;
+ // Prioritize the case that the op already says that it does not need
+ // padding.
+ if (!packOp.getPaddingValue()) {
+ needPadding = false;
+ }
+ if (isTiled && needPadding) {
return failure();
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index da3592547e125..3d32ddd9bed84 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,19 +451,19 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
- %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
- %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
- %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
scf.forall.in_parallel {
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
}
}
- %1 = tensor.empty() : tensor<23x32x3x16xf32>
+ %1 = tensor.empty() : tensor<23x2x3x16xf32>
%cst = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
- return %pack : tensor<23x32x3x16xf32>
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x2x3x16xf32>
+ return %pack : tensor<23x2x3x16xf32>
}
module attributes {transform.with_named_sequence} {
@@ -478,24 +478,51 @@ module attributes {transform.with_named_sequence} {
// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32>
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2)
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1]
-// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
// CHECK-SAME: into %[[TILED_PACK_DEST]]
// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<23x32x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+ return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
// -----
>From 864a9a55c6fafea9ff41f10e9ed757bfce0409cc Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 11:16:35 -0700
Subject: [PATCH 3/5] Fix IR bug
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Interfaces/TilingInterface/tile-and-fuse-consumer.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 3d32ddd9bed84..fc64733d7a887 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -399,7 +399,7 @@ module {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
- %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
^bb0(%in: f32, %in_16: f32, %out: f32):
@@ -434,7 +434,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
-// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK-SAME: {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
>From 583639db7bacdc197623df0e20185b865ffa095d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 12:24:57 -0700
Subject: [PATCH 4/5] Fix the fusion logic and add more lit tests
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 59 ++++++----
.../tile-and-fuse-consumer.mlir | 107 ++++++++++++++++--
2 files changed, 137 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f609c65818e43..0513fbfe28148 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -895,48 +895,63 @@ struct PackOpTiling
packOp.getDimAndTileMapping();
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
if (dimAndTileMapping.count(dim)) {
- // Currently fusing `packOp` as consumer only expects perfect tiling
- // scenario because even if without padding semantic, the `packOp` may
- // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
- // where the `tileSize` from operand of `packOp` is 5, which is not
- // exactly divided by `innerTile`(=6) of `packOp`. As the result:
- // 1. the first slice is extracted from (0) to (4) and inserted into
- // (0,0)~(0,4) at first row.
- // 2. the second slice is extracted from (5) to (9) and SHOULD BE
- // respectively inserted into two rows with different length, including
- // first row: (0,5) and second row (1,0)~(1,3).
- // It is hard to coordinate them, thus adding below constraint to bypass
- // them temporarily. In another word, we can only support tiling with
- // consumer if the tile size for the producer is either a multiple of
- // the inner tile size for the packed dimensions or the dimension is not
- // tiled at this moment.
FailureOr<int64_t> cstTileSize =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, sizes[dim],
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize =
getConstantIntValue(dimAndTileMapping[dim]);
+
// If a dimension is not tiled, it is always valid to fuse the pack op,
// even if the op has padding semantics. Because it always generates a
// full slice along the dimension.
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
// hard check to determine if a dimension is tiled or not.
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
+ int64_t destDimSize = packOp.getDestType().getDimSize(dim);
bool isTiled = failed(cstTileSize) ||
ShapedType::isDynamic(srcDimSize) ||
cstTileSize.value() != srcDimSize;
- int64_t destDimSize = packOp.getDestType().getDimSize(dim);
- bool needPadding = ShapedType::isDynamic(destDimSize) ||
+ if (!isTiled) {
+ outerDimOffsets.push_back(offsets[dim]);
+ if (ShapedType::isStatic(destDimSize)) {
+ outerDimSizes.push_back(b.getIndexAttr(destDimSize));
+ } else {
+ outerDimSizes.push_back(
+ b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
+ }
+ continue;
+ }
+
+ // If the dimension needs padding, it is not supported because there are
+ // iterations that only write padding values to the whole tile. The
+ // consumer fusion is driven by the source, so it is not possible to map
+ // an empty slice to the tile.
+ bool needExtraPadding = ShapedType::isDynamic(destDimSize) ||
!cstInnerSize ||
destDimSize * cstInnerSize.value() != srcDimSize;
// Prioritize the case that the op already says that it does not need
// padding.
- if (!packOp.getPaddingValue()) {
- needPadding = false;
- }
- if (isTiled && needPadding) {
+ if (!packOp.getPaddingValue())
+ needExtraPadding = false;
+ if (needExtraPadding)
+ return failure();
+
+ // Currently fusing `packOp` as consumer only expects perfect tiling
+ // scenario because even if without padding semantic, the `packOp` may
+ // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
+ // where the `tileSize` from operand of `packOp` is 5, which is not
+ // exactly divided by `innerTile`(=6) of `packOp`. As the result:
+ // 1. the first slice is extracted from (0) to (4) and inserted into
+ // (0,0)~(0,4) at first row.
+ // 2. the second slice is extracted from (5) to (9) and SHOULD BE
+ // respectively inserted into two rows with different length, including
+ // first row: (0,5) and second row (1,0)~(1,3).
+ // It is hard to coordinate them, thus adding below constraint to bypass
+ // them temporarily.
+ if ((failed(cstTileSize) || !cstInnerSize ||
+ *cstTileSize % *cstInnerSize != 0))
return failure();
- }
using AV = affine::AffineValueExpr;
affine::AffineBuilder ab(b, loc);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index fc64733d7a887..daa8341ca5a28 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -429,7 +429,7 @@ module attributes {transform.with_named_sequence} {
}
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+// CHECK: func.func @fuse_perfect_tiling_pack_consumer(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
@@ -451,7 +451,10 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
+// It is valid to fuse the pack op with padding semantics if the dimension does
+// not need padding.
+
+func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -475,7 +478,7 @@ module attributes {transform.with_named_sequence} {
}
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
+// CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
@@ -488,18 +491,72 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
// CHECK-SAME: into %[[TILED_PACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
// -----
-func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+// It is valid to fuse the pack op if the dimension is not tiled, even when it
+// needs extra padding.
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<33x2x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+ return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -526,6 +583,42 @@ module attributes {transform.with_named_sequence} {
// -----
+// Imperfect tiling is not supported in pack op consumer fusion.
+
+#map = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0)>
+func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
+ %0 = tensor.empty() : tensor<30xf32>
+ %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
+ %3 = affine.apply #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+ %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %in, %in : f32
+ linalg.yield %5 : f32
+ } -> tensor<5xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
+ }
+ }
+ %2 = tensor.empty() : tensor<5x6xf32>
+ %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
+ return %pack : tensor<5x6xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
module {
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
%c0 = arith.constant 0 : index
>From a5305bdde964b336273156b5e5b501618847af9d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 12:36:01 -0700
Subject: [PATCH 5/5] Recover the comment and add one more test.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 8 +--
.../tile-and-fuse-consumer.mlir | 49 +++++++++++++++++++
2 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0513fbfe28148..bc3e71d3b9b6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -946,9 +946,11 @@ struct PackOpTiling
// (0,0)~(0,4) at first row.
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
// respectively inserted into two rows with different length, including
- // first row: (0,5) and second row (1,0)~(1,3).
- // It is hard to coordinate them, thus adding below constraint to bypass
- // them temporarily.
+ // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
+ // them, so the constraint below bypasses them temporarily. In other
+ // words, consumer fusion is currently supported only when the tile size
+ // of the producer is a multiple of the inner tile size for the packed
+ // dimensions.
if ((failed(cstTileSize) || !cstInnerSize ||
*cstTileSize % *cstInnerSize != 0))
return failure();
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index daa8341ca5a28..ef9b9454c946e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,6 +451,55 @@ module attributes {transform.with_named_sequence} {
// -----
+// It is valid to fuse the pack op in the perfect tiling scenario when the
+// dimension is dynamic and padding is not needed.
+
+func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> {
+ %c1 = arith.constant 1 : index
+ %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32>
+ %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32>
+ }
+ }
+ %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32>
+ return %pack : tensor<64x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1]
+
+// -----
+
// It is valid to fuse the pack op with padding semantics if the dimension does
// not need padding.
More information about the Mlir-commits
mailing list