[Mlir-commits] [mlir] [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (PR #93055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 22 08:39:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: None (Max191)

<details>
<summary>Changes</summary>

This PR adds support for folding `tensor.pack`/`tensor.unpack` with transposes that are expressed as `linalg.generic` ops. The previous patterns also composed permutations incorrectly in some cases (e.g. for transposes that are not their own inverse), so this PR fixes those bugs and adds tests covering them.

---
Full diff: https://github.com/llvm/llvm-project/pull/93055.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+87-39) 
- (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+139) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..ce5fda8e79d65 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,46 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure.
+//
+// Besides `linalg.transpose`, this recognizes a `linalg.generic` whose loops
+// are all parallel, that has exactly one input and one init with distinct
+// permutation indexing maps, and whose body only yields the input element.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+    return SmallVector<int64_t>(transposeOp.getPermutation());
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return failure();
+
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+    return failure();
+  }
+  // The body must be a lone terminator forwarding the input element. A body
+  // yielding the init element (e.g. `linalg.yield %out`) is also a single op,
+  // but it leaves the init unchanged and is therefore not a transpose.
+  Block *body = linalgOp.getBlock();
+  if (!llvm::hasSingleElement(body->getOperations()))
+    return failure();
+  BlockArgument inputArg =
+      linalgOp.getMatchingBlockArgument(linalgOp.getDpsInputOperand(0));
+  if (body->getTerminator()->getOperand(0) != inputArg)
+    return failure();
+  AffineMap outMap = mapRange.back();
+  AffineMap inMap = mapRange.front();
+  // To get the permutation, look at each output index and find which
+  // dimension in the input we're reading from for that index.
+  return llvm::map_to_vector(outMap.getResults(),
+                             [&](AffineExpr expr) -> int64_t {
+                               return *inMap.getResultPosition(expr);
+                             });
+}
+
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -244,14 +272,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-
-    if (!inVec.empty()) {
-      if (remappedPosition >= rank) {
-        return false;
-      }
+    if (remappedPosition >= rank)
+      return false;
+    if (!inVec.empty())
       remappedPosition = inVec[remappedPosition];
-    }
-
     resVec.push_back(remappedPosition);
   }
 
@@ -261,20 +285,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
+    auto transposePerm = maybePerm.value();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -283,7 +313,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          transposeOp,
+          linalgOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
@@ -295,11 +325,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, transposeOp.getLoc(), packOp.getSource(),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+        newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -314,12 +344,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -335,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -349,22 +384,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -372,11 +414,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
@@ -391,13 +433,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -406,7 +454,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(transposePermutation, outerDimsPerm,
+    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
@@ -414,18 +462,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
           "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
-      int64_t remappedPosition = transposePermutation[i] - destRank;
+    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9f486f9146ad8..fca6eddaca436 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 //  CHECK-SAME:        into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 //       CHECK:       return %[[UNPACK]] : tensor<100x71x64xf32>
 //       CHECK:    }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 1, 4, 3]
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.transpose
+    ins(%unpack : tensor<3x56x3648xf32>)
+    outs(%1 : tensor<3648x3x56xf32>)
+    permutation = [2, 0, 1]
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_non_involution_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 4, 1, 3]
+  %1 = tensor.empty() : tensor<5x32x12xi32>
+  %unpack = tensor.unpack %transposed
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+  return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL:  func.func @transpose_unpacked_dims_no_fold(
+//      CHECK:     %[[TRANSPOSED:.+]] = linalg.transpose
+//      CHECK:     tensor.unpack %[[TRANSPOSED]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>) {
+  ^bb0(%in : i32, %out : i32):
+    linalg.yield %in : i32
+  } -> tensor<5x2x3x16x4xi32>
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @generic_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%unpack : tensor<3x56x3648xf32>)
+                outs(%1 : tensor<3648x3x56xf32>) {
+  ^bb0(%in : f32, %out : f32):
+    linalg.yield %in : f32
+  } -> tensor<3648x3x56xf32>
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_generic_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }

``````````

</details>


https://github.com/llvm/llvm-project/pull/93055


More information about the Mlir-commits mailing list