[Mlir-commits] [mlir] Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (PR #107271)
Benoit Jacob
llvmlistbot at llvm.org
Wed Sep 4 11:14:27 PDT 2024
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/107271
From 20c28821fdb29fb4679e638c7d260abaf94b3147 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 4 Sep 2024 13:02:12 -0400
Subject: [PATCH] fold-transpose-unpack-partial-tile
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../Transforms/PackAndUnpackPatterns.cpp | 13 ++++---
.../Tensor/fold-into-pack-and-unpack.mlir | 35 +++++++++++++++----
2 files changed, 38 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..995486c87771a3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
if (failed(maybePerm))
return failure();
+ SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+ if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+ return failure();
+ }
+
SmallVector<int64_t> inverseTransposePerm =
invertPermutationVector(maybePerm.value());
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,7 +453,6 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
@@ -463,9 +467,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
}
- Value output = unPackOp.createDestinationTensor(
- rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+ auto elemType =
+ cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
+ Value output = rewriter.create<tensor::EmptyOp>(
+ unPackOp->getLoc(), unpackOpResultDims[0], elemType);
rewriter.replaceOpWithNewOp<UnPackOp>(
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
// -----
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+ %0 = tensor.empty() : tensor<1x1x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+ outs(%0 : tensor<1x1x16x4xi32>)
+ permutation = [1, 0, 3, 2]
+ %1 = tensor.empty() : tensor<15x3xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [16, 4] into
+ %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+ return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0]
+// CHECK-SAME: inner_dims_pos = [1, 0]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+// CHECK: return %[[UNPACK]] : tensor<15x3xi32>
+// CHECK: }
+
+// -----
+
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
%transposed = linalg.transpose
ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
return %unpack : tensor<?x?xf32>
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-// CHECK-DAG: %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-// CHECK: %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 1]
// CHECK-SAME: inner_dims_pos = [1, 0]
More information about the Mlir-commits mailing list