[Mlir-commits] [mlir] [mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (PR #93055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 4 13:17:14 PDT 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/93055

>From e4224468b3c5b244d0f8fdcaa858d10c43a3d882 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 21 May 2024 12:01:38 -0400
Subject: [PATCH 1/2] [mlir] Add pack transpose foldings for linalg.generic
 transpose ops and fix bugs

---
 .../Transforms/PackAndUnpackPatterns.cpp      | 126 +++++++++++-----
 .../Tensor/fold-into-pack-and-unpack.mlir     | 139 ++++++++++++++++++
 2 files changed, 226 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 5d6e3ec9756af..dd68928d77497 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,42 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure. Besides linalg.transpose, this
+// accepts an all-parallel linalg op with exactly one input and one init,
+// distinct permutation indexing maps, and a body that only yields the input
+// element. In the returned `perm`, output dim `i` reads input dim `perm[i]`.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+    return SmallVector<int64_t>(transposeOp.getPermutation());
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return failure();
+
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+    return failure();
+  }
+  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+    return failure();
+  // The lone op is the linalg.yield terminator; it must forward the input
+  // element (block argument 0), not the init element or any other value.
+  Block *body = linalgOp.getBlock();
+  if (body->getTerminator()->getOperand(0) != body->getArgument(0))
+    return failure();
+  AffineMap outMap = mapRange.back();
+  AffineMap inMap = mapRange.front();
+  // To get the permutation, look at each output index and find which
+  // dimension in the input we're reading from for that index.
+  return llvm::map_to_vector(outMap.getResults(),
+                             [&](AffineExpr expr) -> int64_t {
+                               return *inMap.getResultPosition(expr);
+                             });
+}
+
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-
-    if (!inVec.empty()) {
-      if (remappedPosition >= rank) {
-        return false;
-      }
+    if (remappedPosition >= rank)
+      return false;
+    if (!inVec.empty())
       remappedPosition = inVec[remappedPosition];
-    }
-
     resVec.push_back(remappedPosition);
   }
 
@@ -263,20 +287,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
+    auto transposePerm = maybePerm.value();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +315,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          transposeOp,
+          linalgOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
@@ -297,11 +327,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, transposeOp.getLoc(), packOp.getSource(),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+        newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -316,12 +346,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -337,11 +372,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -351,22 +386,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -374,11 +416,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
@@ -393,13 +435,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -408,7 +456,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(transposePermutation, outerDimsPerm,
+    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
@@ -416,18 +464,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
           "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
-      int64_t remappedPosition = transposePermutation[i] - destRank;
+    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9a3143f5e550e..629a4c2135720 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 //  CHECK-SAME:        into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 //       CHECK:       return %[[UNPACK]] : tensor<100x71x64xf32>
 //       CHECK:    }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 1, 4, 3]
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
// CHECK-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.transpose
+    ins(%unpack : tensor<3x56x3648xf32>)
+    outs(%1 : tensor<3648x3x56xf32>)
+    permutation = [2, 0, 1]
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_non_involution_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
//  CHECK-SAME:        into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 4, 1, 3]
+  %1 = tensor.empty() : tensor<5x32x12xi32>
+  %unpack = tensor.unpack %transposed
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+  return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL:  func.func @transpose_unpacked_dims_no_fold(
+//      CHECK:     linalg.transpose
+//      CHECK:     tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>) {
+  ^bb0(%in : i32, %out : i32):
+    linalg.yield %in : i32
+  } -> tensor<5x2x3x16x4xi32>
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @generic_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
// CHECK-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%unpack : tensor<3x56x3648xf32>)
+                outs(%1 : tensor<3648x3x56xf32>) {
+  ^bb0(%in : f32, %out : f32):
+    linalg.yield %in : f32
+  } -> tensor<3648x3x56xf32>
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_generic_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
//  CHECK-SAME:        into %[[OUT]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }

>From 0d667d8740c695288b7e7891e5fc539e906d94b5 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 4 Jun 2024 16:17:04 -0400
Subject: [PATCH 2/2] address comments

---
 .../Transforms/PackAndUnpackPatterns.cpp      | 33 ++++++++-----------
 1 file changed, 13 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index dd68928d77497..c681cadcb27cb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -60,8 +60,8 @@ getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
   if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
     return failure();
   auto mapRange = linalgOp.getIndexingMapsArray();
-  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
-      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
+      mapRange.front() == mapRange.back()) {
     return failure();
   }
   if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
@@ -299,9 +299,8 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
 
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm)) {
+    if (failed(maybePerm))
       return failure();
-    }
 
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
@@ -352,9 +351,8 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm)) {
+    if (failed(maybePerm))
       return failure();
-    }
 
     auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
@@ -398,25 +396,22 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
 
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm)) {
+    if (failed(maybePerm))
       return failure();
-    }
 
-    auto transposePermutation = maybePerm.value();
-    SmallVector<int64_t> inverseTransposePerm =
-        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
-
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+    SmallVector<int64_t> newOuterDimsPermVec =
+        invertPermutationVector(maybePerm.value());
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
+      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
+
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
@@ -441,13 +436,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm)) {
+    if (failed(maybePerm))
       return failure();
-    }
 
-    auto transposePermutation = maybePerm.value();
     SmallVector<int64_t> inverseTransposePerm =
-        invertPermutationVector(transposePermutation);
+        invertPermutationVector(maybePerm.value());
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();



More information about the Mlir-commits mailing list