[Mlir-commits] [mlir] Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (PR #107271)

Benoit Jacob llvmlistbot at llvm.org
Wed Sep 4 11:14:27 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/107271

From 20c28821fdb29fb4679e638c7d260abaf94b3147 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 4 Sep 2024 13:02:12 -0400
Subject: [PATCH] fold-transpose-unpack-partial-tile

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../Transforms/PackAndUnpackPatterns.cpp      | 13 ++++---
 .../Tensor/fold-into-pack-and-unpack.mlir     | 35 +++++++++++++++----
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..995486c87771a3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     if (failed(maybePerm))
       return failure();
 
+    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+      return failure();
+    }
+
     SmallVector<int64_t> inverseTransposePerm =
         invertPermutationVector(maybePerm.value());
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,7 +453,6 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
     if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
@@ -463,9 +467,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
-    Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+    auto elemType =
+        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
+    Value output = rewriter.create<tensor::EmptyOp>(
+        unPackOp->getLoc(), unpackOpResultDims[0], elemType);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
         unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
 
 // -----
 
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+  %0 = tensor.empty() : tensor<1x1x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+                outs(%0 : tensor<1x1x16x4xi32>)
+                permutation = [1, 0, 3, 2]
+  %1 = tensor.empty() : tensor<15x3xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 1]
+            inner_dims_pos = [0, 1]
+            inner_tiles = [16, 4] into
+            %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+  return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL:  func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [1, 0]
+// CHECK-SAME:        inner_dims_pos = [1, 0]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<15x3xi32>
+//      CHECK:   }
+
+// -----
+
 func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
   %transposed = linalg.transpose
     ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
     into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
   return %unpack : tensor<?x?xf32>
 }
-//       CHECK:    #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL:   func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
 //   CHECK-DAG:       %[[CST1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:       %[[CST0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-//   CHECK-DAG:       %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
 //       CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:         outer_dims_perm = [0, 1]
 //  CHECK-SAME:         inner_dims_pos = [1, 0]



More information about the Mlir-commits mailing list