[Mlir-commits] [mlir] Fix `transpose->unpack` folding pattern for the partial-tile case of `unpack` (PR #107271)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 4 10:42:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

<details>
<summary>Changes</summary>

`UnPackOp::createDestinationTensor` was trying to infer the destination shape, which isn't possible in general with the set of parameters that it was taking, in the case of a partial-tile `unpack`, where `unpack` has extract-slice semantics.

Added an optional (default empty) additional parameter to `UnPackOp::createDestinationTensor` to allow passing the destination shape. Went over existing callers. Only one needed to pass it explicitly, others are in the full-tile case where the existing code was fine.

---
Full diff: https://github.com/llvm/llvm-project/pull/107271.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+2-1) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-5) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+9-3) 
- (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+29-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..8040cc97cd8bc4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -2076,7 +2076,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
-        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+        SmallVector<OpFoldResult> mixedSizes = {});
 
     /// Build and return a new UnPackOp that is a clone of the current UnPackOp
     /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 996de530c255d4..41afbbe840352c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4360,15 +4360,19 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                         Value source,
                                         ArrayRef<OpFoldResult> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
-                                        ArrayRef<int64_t> outerDimsPerm) {
+                                        ArrayRef<int64_t> outerDimsPerm,
+                                        SmallVector<OpFoldResult> mixedSizes) {
+  auto srcType = llvm::cast<RankedTensorType>(source.getType());
+  auto elemType = srcType.getElementType();
+  if (!mixedSizes.empty()) {
+    return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+  }
+
   AffineExpr sym0, sym1;
   bindSymbols(b.getContext(), sym0, sym1);
   auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
     return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
   };
-
-  SmallVector<OpFoldResult> mixedSizes;
-  auto srcType = llvm::cast<RankedTensorType>(source.getType());
   for (auto i :
        llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
     if (srcType.isDynamicDim(i))
@@ -4384,7 +4388,6 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
   for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
     mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
 
-  auto elemType = srcType.getElementType();
   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
 }
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..fdd6ff47f3bb5e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     if (failed(maybePerm))
       return failure();
 
+    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+      return failure();
+    }
+
     SmallVector<int64_t> inverseTransposePerm =
         invertPermutationVector(maybePerm.value());
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,13 +453,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
     if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
-                         newOuterDimsPermVec, destRank))
+                         newOuterDimsPermVec, destRank)) {
       return rewriter.notifyMatchFailure(
           unPackOp,
           "Cannot fold in tensor.unpack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
+    }
 
     // Process transpose operation for tiled inner dimensions
     for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
@@ -465,7 +470,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     Value output = unPackOp.createDestinationTensor(
         rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec,
+        unpackOpResultDims[0]);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
         unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
 
 // -----
 
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+  %0 = tensor.empty() : tensor<1x1x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+                outs(%0 : tensor<1x1x16x4xi32>)
+                permutation = [1, 0, 3, 2]
+  %1 = tensor.empty() : tensor<15x3xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 1]
+            inner_dims_pos = [0, 1]
+            inner_tiles = [16, 4] into
+            %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+  return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL:  func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [1, 0]
+// CHECK-SAME:        inner_dims_pos = [1, 0]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<15x3xi32>
+//      CHECK:   }
+
+// -----
+
 func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
   %transposed = linalg.transpose
     ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
     into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
   return %unpack : tensor<?x?xf32>
 }
-//       CHECK:    #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL:   func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
 //   CHECK-DAG:       %[[CST1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:       %[[CST0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-//   CHECK-DAG:       %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
 //       CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:         outer_dims_perm = [0, 1]
 //  CHECK-SAME:         inner_dims_pos = [1, 0]

``````````

</details>


https://github.com/llvm/llvm-project/pull/107271


More information about the Mlir-commits mailing list