[Mlir-commits] [mlir] c1667f9 - Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (#107271)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 4 12:06:30 PDT 2024


Author: Benoit Jacob
Date: 2024-09-04T15:06:27-04:00
New Revision: c1667f909949d15c593e4a03a4e992cffa72ad3c

URL: https://github.com/llvm/llvm-project/commit/c1667f909949d15c593e4a03a4e992cffa72ad3c
DIFF: https://github.com/llvm/llvm-project/commit/c1667f909949d15c593e4a03a4e992cffa72ad3c.diff

LOG: Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (#107271)

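Directly create an empty destination tensor whose shape is the reified
result shape of the original `unpack`, instead of relying on
`UnPackOp::createDestinationTensor`. That helper tries to infer the
destination shape from its parameters (outer dims multiplied by inner tile
sizes), which isn't possible in general: when `unpack` drops a partial tile,
the destination is smaller than that product. For example, unpacking
`tensor<1x1x16x4xi32>` with `inner_tiles = [16, 4]` into `tensor<15x3xi32>`
would otherwise yield an inferred shape of 16x4 instead of 15x3.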

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
    mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..995486c87771a3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     if (failed(maybePerm))
       return failure();
 
+    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+      return failure();
+    }
+
     SmallVector<int64_t> inverseTransposePerm =
         invertPermutationVector(maybePerm.value());
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,7 +453,6 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
     if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
@@ -463,9 +467,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
-    Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+    auto elemType =
+        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
+    Value output = rewriter.create<tensor::EmptyOp>(
+        unPackOp->getLoc(), unpackOpResultDims[0], elemType);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
         unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,

diff  --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
 
 // -----
 
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+  %0 = tensor.empty() : tensor<1x1x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+                outs(%0 : tensor<1x1x16x4xi32>)
+                permutation = [1, 0, 3, 2]
+  %1 = tensor.empty() : tensor<15x3xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 1]
+            inner_dims_pos = [0, 1]
+            inner_tiles = [16, 4] into
+            %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+  return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL:  func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [1, 0]
+// CHECK-SAME:        inner_dims_pos = [1, 0]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<15x3xi32>
+//      CHECK:   }
+
+// -----
+
 func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
   %transposed = linalg.transpose
     ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
     into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
   return %unpack : tensor<?x?xf32>
 }
-//       CHECK:    #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL:   func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
 //   CHECK-DAG:       %[[CST1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:       %[[CST0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-//   CHECK-DAG:       %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
 //       CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:         outer_dims_perm = [0, 1]
 //  CHECK-SAME:         inner_dims_pos = [1, 0]


        


More information about the Mlir-commits mailing list