[Mlir-commits] [mlir] Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (PR #107271)
Benoit Jacob
llvmlistbot at llvm.org
Wed Sep 4 10:07:51 PDT 2024
https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/107271
`UnPackOp::createDestinationTensor` was trying to infer the destination shape, which isn't possible in general with the set of parameters that it was taking, in the case of a partial-tile `unpack` where `unpack` has extract-slice semantics.
Added an optional (default empty) additional parameter to `UnPackOp::createDestinationTensor` to allow passing the destination shape. Went over existing callers. Only one needed to pass it explicitly, others are in the full-tile case where the existing code was fine.
>From c80fb599c325345b39f6f92264d1394fe88ba40d Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 4 Sep 2024 13:02:12 -0400
Subject: [PATCH] fold-transpose-unpack-partial-tile
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 3 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 13 ++++---
.../Transforms/PackAndUnpackPatterns.cpp | 12 +++++--
.../Tensor/fold-into-pack-and-unpack.mlir | 35 +++++++++++++++----
4 files changed, 48 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..8040cc97cd8bc4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -2076,7 +2076,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ SmallVector<OpFoldResult> mixedSizes = {});
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 996de530c255d4..41afbbe840352c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4360,15 +4360,19 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
Value source,
ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ SmallVector<OpFoldResult> mixedSizes) {
+ auto srcType = llvm::cast<RankedTensorType>(source.getType());
+ auto elemType = srcType.getElementType();
+ if (!mixedSizes.empty()) {
+ return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+ }
+
AffineExpr sym0, sym1;
bindSymbols(b.getContext(), sym0, sym1);
auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
};
-
- SmallVector<OpFoldResult> mixedSizes;
- auto srcType = llvm::cast<RankedTensorType>(source.getType());
for (auto i :
llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
if (srcType.isDynamicDim(i))
@@ -4384,7 +4388,6 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
- auto elemType = srcType.getElementType();
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..fdd6ff47f3bb5e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
if (failed(maybePerm))
return failure();
+ SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+ if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+ return failure();
+ }
+
SmallVector<int64_t> inverseTransposePerm =
invertPermutationVector(maybePerm.value());
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,13 +453,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
- newOuterDimsPermVec, destRank))
+ newOuterDimsPermVec, destRank)) {
return rewriter.notifyMatchFailure(
unPackOp,
"Cannot fold in tensor.unpack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");
+ }
// Process transpose operation for tiled inner dimensions
for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
@@ -465,7 +470,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
Value output = unPackOp.createDestinationTensor(
rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+ newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec,
+ unpackOpResultDims[0]);
rewriter.replaceOpWithNewOp<UnPackOp>(
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
// -----
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+ %0 = tensor.empty() : tensor<1x1x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+ outs(%0 : tensor<1x1x16x4xi32>)
+ permutation = [1, 0, 3, 2]
+ %1 = tensor.empty() : tensor<15x3xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [16, 4] into
+ %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+ return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0]
+// CHECK-SAME: inner_dims_pos = [1, 0]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+// CHECK: return %[[UNPACK]] : tensor<15x3xi32>
+// CHECK: }
+
+// -----
+
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
%transposed = linalg.transpose
ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
return %unpack : tensor<?x?xf32>
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-// CHECK-DAG: %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-// CHECK: %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [1, 0]
More information about the Mlir-commits
mailing list