[llvm-branch-commits] [mlir] 410d9d0 - Revert "[mlir][linalg] Restrict linalg.pack to not have artificial padding. (…"
Author: Han-Chung Wang
Date: 2025-07-25T11:26:37-07:00
New Revision: 410d9d069eeeb8a09f98a4b576b1d7d99db1b9b9
URL: https://github.com/llvm/llvm-project/commit/410d9d069eeeb8a09f98a4b576b1d7d99db1b9b9
DIFF: https://github.com/llvm/llvm-project/commit/410d9d069eeeb8a09f98a4b576b1d7d99db1b9b9.diff
LOG: Revert "[mlir][linalg] Restrict linalg.pack to not have artificial padding. (…"
This reverts commit 773e158c64735a80b814f20be6b959d9577531f8.
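Context for the revert: the reverted change made the pack/unpack verifier require the destination's tiled outer dimensions to equal CeilDiv(source extent, inner tile), i.e. it rejected any "artificial" padding beyond what is needed to complete the last tile. With this revert, a destination that is merely large enough verifies again. A minimal sketch of the kind of IR affected, reusing the shapes from the example that the reverted commit had documented as invalid (%src, %cst, and %dest are assumed to be defined elsewhere):

```mlir
// CeilDiv(9, 8) = 2, so the tightest destination is tensor<2x8xf32>. After this
// revert the oversized tensor<3x8xf32> destination is accepted again; only a
// destination too small to hold the packed data is rejected.
%0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
    inner_tiles = [8] into %dest
    : tensor<9xf32> -> tensor<3x8xf32>
```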
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f36b41ccf6745..fa572024ff72b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -106,9 +106,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
result tensor in the order in which they appear, i.e.
`shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
- The following relationship for the tiled dimensions holds:
- `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`,
- where (⌈/⌉ indicates CeilDiv).
-
+ `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
`...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
@@ -152,17 +150,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
`padding_value` specifies a padding value at the boundary on non-perfectly
divisible dimensions. Padding is optional:
- - If absent, it is assumed that for all inner tiles,
- `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
- tiles divide perfectly the corresponding outer dimension in the result
- tensor. It is UB if the tile does not perfectly divide the dimension.
+ - If absent, it is UB if the tile does not perfectly divide the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete. Note that it is not allowed to have artificial padding that
- is not strictly required by linalg.pack (i.e., padding past what is needed
- to complete the last tile along each packed dimension). It is UB if extra
- padding is requested.
- It is not possible to verify the requirements statically with dynamic
- shapes, so they are treated as UB.
+ tile complete.
Example:
```mlir
@@ -177,15 +167,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
//
// Note: Only tiled dimensions can be padded.
```
-
- Invalid example that has artificial padding:
- ```mlir
- %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
- inner_tiles = [8] into %dest
- : tensor<9xf32> -> tensor<3x8xf32>
- // \
- // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
- ```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
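The relaxed semantics apply to linalg.unpack as well, since both ops share commonVerifierPackAndUnPackOp (see the LinalgOps.cpp hunk below). A hedged illustration using the shapes of the unpack test that this revert deletes from invalid.mlir further down (%input and %output are assumed function arguments):

```mlir
// The inferred packed type for tensor<9xf32> with inner_tiles = [8] is
// tensor<2x8xf32>; a larger packed source such as tensor<3x8xf32> now verifies,
// and the elements of the surplus tile simply do not appear in %output.
%0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
    : tensor<3x8xf32> -> tensor<9xf32>
```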
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e25d063fce97b..4fee81aa2ef67 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,7 +32,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -4625,6 +4624,22 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
+/// Returns true if the dimension of `sourceShape` is smaller than the dimension
+/// of the `limitShape`.
+static bool areAllInBound(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> limitShape) {
+ assert(
+ sourceShape.size() == limitShape.size() &&
+ "expected source shape rank, and limit of the shape to have same rank");
+ return llvm::all_of(
+ llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
+ int64_t sourceExtent = std::get<0>(it);
+ int64_t limit = std::get<1>(it);
+ return ShapedType::isDynamic(sourceExtent) ||
+ ShapedType::isDynamic(limit) || sourceExtent <= limit;
+ });
+}
+
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4683,6 +4698,11 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+ if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+ return op->emitError("the shape of output is not large enough to hold the "
+ "packed data. Expected at least ")
+ << expectedPackedType << ", got " << packedType;
+ }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4699,12 +4719,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
- if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
- packedType.getShape()))) {
- return op->emitError("expected ")
- << expectedPackedType << " for the packed domain value, got "
- << packedType;
- }
return success();
}
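To make the restored check concrete: areAllInBound compares the inferred packed type against the declared packed type dimension by dimension, treating dynamic extents as in-bound, so only an undersized destination is diagnosed. A sketch using the shapes from the invalid.mlir and canonicalize.mlir updates below (%input, %output, %cst, %pad, and %dest are assumed to be defined elsewhere):

```mlir
// Rejected: the inferred packed type is tensor<16x4x32x16xf32>, and the outer
// dims of the destination (4x16) cannot hold it, so the verifier reports
// "the shape of output is not large enough to hold the packed data".
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16]
    into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>

// Accepted: the inferred packed type is tensor<4x8x8x32xf32>; the oversized
// destination tensor<8x16x8x32xf32> is in bound in every dimension.
%1 = linalg.pack %cst padding_value(%pad : f32)
    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
```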
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 595d2625ee27c..a45a4e314e511 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,7 +10,6 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 39a7b1b1a2775..9cbb56e4de884 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1387,43 +1387,42 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
-// -----
-
//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
- return %0 : tensor<4x8x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
}
// -----
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
- return %0 : tensor<4x8x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
}
+
// -----
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
@@ -1431,8 +1430,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
- into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
- return %0 : tensor<4x8x8x32xf32>
+ into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
}
// -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,6 +1295,24 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
+func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
+ %empty = tensor.empty() : tensor<8x4x16x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+ %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+ return %pack : tensor<8x4x16x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
+// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+
+// -----
+
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..da1dfc7b6a624 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,7 +1760,6 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}
// -----
-
func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error at +1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1825,47 +1824,27 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
-func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
- %cst = arith.constant 0.0 : f32
- // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
- %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
- inner_tiles = [8] into %output
- : tensor<9xf32> -> tensor<3x8xf32>
- return %0 : tensor<3x8xf32>
-}
-
-// -----
-
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error at +1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
+ // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
// -----
-func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
- // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
- return %0 : tensor<8x7x16x32xf32>
-}
-
-// -----
-
-func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
- // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
- : tensor<3x8xf32> -> tensor<9xf32>
- return %0 : tensor<9xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+ %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+ return %0 : tensor<8x8x32x16xf32>
}
// -----
-func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
- // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
+ // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 9e7681d1a1b7d..81fd7a8a947d7 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
- -> tensor<265x12x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
+ -> tensor<265x16x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
- // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
+ // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
- // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
+ // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
- // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
- : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
- return %0 : tensor<265x12x16x1xf32>
+ : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
+ return %0 : tensor<265x16x16x1xf32>
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index e48e5c6c308be..cdbca7228ded3 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -646,6 +646,87 @@ module attributes {transform.with_named_sequence} {
// -----
+// It is valid to fuse the pack if the dimension is not tiled even when it needs
+// extra padding.
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<33x2x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+ return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<23x32x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+ return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// Imperfect tiling is not supported in pack op consumer fusion.
#map = affine_map<(d0) -> (d0 * 5)>