[llvm] [mlir] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)

Prathamesh Tagore via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 8 12:27:39 PST 2023


https://github.com/meshtag updated https://github.com/llvm/llvm-project/pull/74206

From febd3de154543df1a4828994bfe716fee5cf4e52 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Thu, 30 Nov 2023 21:44:17 +0000
Subject: [PATCH 1/8] First working draft for both cases

---
 .../FoldIntoPackAndUnpackPatterns.cpp         |  81 +++++++++++-
 .../Tensor/fold-into-pack-and-unpack.mlir     | 121 +++++-------------
 result.mlir                                   |   0
 3 files changed, 110 insertions(+), 92 deletions(-)
 create mode 100644 result.mlir

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef91..ca69341afb51c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,11 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
+#include <iostream>
+
 namespace mlir {
 namespace tensor {
 namespace {
@@ -81,10 +84,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeInputTensor = transposeOp.getOperand(0);
+    auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+      newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+
+    // Create a new empty output tensor.
+    Type elementType = packOp.getDestType().getElementType();
+    auto transposeOpResultType = transposeOp.getResult().getType()[0];
+    auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
+    Value output = rewriter.create<EmptyOp>(
+        transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), std::nullopt,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto packInputTensor = packOp.getOperand(0);
+    auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+      newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+
+    // Create a new empty output tensor.
+    Type elementType = packOp.getDestType().getElementType();
+    auto packOpResultType = packOp.getResult().getType();
+    auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
+    Value output = rewriter.create<EmptyOp>(
+        packOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), std::nullopt,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+    return success();
+  }
+};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+                  FoldProducerPackWithConsumerLinalgTransposeOp,
+                  FoldConsumerPackWithProducerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c75789665742..bc67b75480383 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,100 +17,39 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[UNPACK]]
 
-// -----
-
-func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
+func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+
+  // [2, 3, 0, 1]
+
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 3, 1, 2]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+  return %pack : tensor<1x2x56x57x32xf32>
 }
-// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
 
-// -----
 
-func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index) -> tensor<f32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
-  return %1 : tensor<f32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @pad_pack
-// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
-// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
-// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
-
-// -----
-
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @nofold_pad_pack
-// CHECK:         tensor.pad
-// CHECK:         tensor.pack
+func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-// -----
+    // [2, 3, 0, 1]
 
-func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst0 = arith.constant 0.000000e+00 : f32
-  %cst1 = arith.constant 1.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst0 : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
+    %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+    %transposed = linalg.transpose
+    ins(%pack : tensor<56x57x1x2x32xf32>)
+    outs(%1 : tensor<1x2x56x57x32xf32>)
+    permutation = [2, 3, 0, 1, 4]
+  return %transposed : tensor<1x2x56x57x32xf32>
 }
-// CHECK-LABEL: func.func @pad_pack_different_padding_value
-// CHECK:         tensor.pad
-// CHECK:         tensor.pack
diff --git a/result.mlir b/result.mlir
new file mode 100644
index 0000000000000..e69de29bb2d1d

From 57e33eb73e6ff6d498d0cba511419b4245ca3399 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 19:59:54 +0000
Subject: [PATCH 2/8] Add tests

---
 .../Tensor/fold-into-pack-and-unpack.mlir     | 125 +++++++++++++++++-
 1 file changed, 118 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index bc67b75480383..b9b37680cdf47 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,7 +17,105 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[UNPACK]]
 
-func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+// -----
+
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<f32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+
+// -----
+
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst0 = arith.constant 0.000000e+00 : f32
+  %cst1 = arith.constant 1.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst0 : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<1x56x57x64xf32>
   %transposed = linalg.transpose
     ins(%arg0 : tensor<56x57x1x64xf32>)
@@ -25,8 +123,6 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     permutation = [2, 0, 1, 3]
   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
 
-  // [2, 3, 0, 1]
-
   %pack = tensor.pack %transposed
     outer_dims_perm = [0, 3, 1, 2]
     inner_dims_pos = [3]
@@ -34,9 +130,18 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
   return %pack : tensor<1x2x56x57x32xf32>
 }
+//      CHECK: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
 
+// -----
 
-func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
   %pack = tensor.pack %arg0
     outer_dims_perm = [0, 1, 2, 3]
@@ -44,12 +149,18 @@ func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     inner_tiles = [32]
     into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-    // [2, 3, 0, 1]
-
-    %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
     %transposed = linalg.transpose
     ins(%pack : tensor<56x57x1x2x32xf32>)
     outs(%1 : tensor<1x2x56x57x32xf32>)
     permutation = [2, 3, 0, 1, 4]
   return %transposed : tensor<1x2x56x57x32xf32>
 }
+//      CHECK: func @tensor_pack_linalg_transpose_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]

From 354a61dcdc0226b3d1aaba263e7f2b41b325e96f Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:02:12 +0000
Subject: [PATCH 3/8] Remove unnecessary stuff

---
 .../Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp | 2 --
 result.mlir                                                     | 0
 2 files changed, 2 deletions(-)
 delete mode 100644 result.mlir

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index ca69341afb51c..47d85a6f4f9a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -12,8 +12,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
-#include <iostream>
-
 namespace mlir {
 namespace tensor {
 namespace {
diff --git a/result.mlir b/result.mlir
deleted file mode 100644
index e69de29bb2d1d..0000000000000

From 1b9d96d0a111d863a7cc6e66837e55a11f8300bd Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:09:56 +0000
Subject: [PATCH 4/8] Rectify formatting

---
 mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index b9b37680cdf47..0b00c7fa7feb9 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -121,8 +121,8 @@ func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
     ins(%arg0 : tensor<56x57x1x64xf32>)
     outs(%0 : tensor<1x56x57x64xf32>)
     permutation = [2, 0, 1, 3]
-  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
 
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
   %pack = tensor.pack %transposed
     outer_dims_perm = [0, 3, 1, 2]
     inner_dims_pos = [3]
@@ -149,8 +149,8 @@ func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> t
     inner_tiles = [32]
     into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
-    %transposed = linalg.transpose
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %transposed = linalg.transpose
     ins(%pack : tensor<56x57x1x2x32xf32>)
     outs(%1 : tensor<1x2x56x57x32xf32>)
     permutation = [2, 3, 0, 1, 4]

From 840ba3ee5063aa7b4633e681d412786ee37f51b4 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Mon, 4 Dec 2023 14:44:50 +0000
Subject: [PATCH 5/8] Apply first batch of suggested changes

---
 .../FoldIntoPackAndUnpackPatterns.cpp         | 28 ++++++++-----------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 47d85a6f4f9a5..ee75205cba6cf 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -99,21 +99,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
 
     auto packOuterDimsPerm = packOp.getOuterDimsPerm();
     auto transposePerm = transposeOp.getPermutation();
-    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackOuterDimsPermVec;
 
     for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
       newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
 
-    // Create a new empty output tensor.
-    Type elementType = packOp.getDestType().getElementType();
-    auto transposeOpResultType = transposeOp.getResult().getType()[0];
-    auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
-    Value output = rewriter.create<EmptyOp>(
-        transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+    Value output = packOp.createDestinationTensor(
+        rewriter, transposeOp.getLoc(), packOp.getSource(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     rewriter.replaceOpWithNewOp<PackOp>(
         transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
-        packOp.getMixedTiles(), std::nullopt,
+        packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
         static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     return success();
@@ -136,21 +134,19 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     auto packOuterDimsPerm = packOp.getOuterDimsPerm();
     auto transposePerm = transposeOp.getPermutation();
-    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackOuterDimsPermVec;
 
     for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
       newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
 
-    // Create a new empty output tensor.
-    Type elementType = packOp.getDestType().getElementType();
-    auto packOpResultType = packOp.getResult().getType();
-    auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
-    Value output = rewriter.create<EmptyOp>(
-        packOp.getLoc(), rankedTensorType.getShape(), elementType);
+    Value output = packOp.createDestinationTensor(
+        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
-        packOp.getMixedTiles(), std::nullopt,
+        packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
         static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     return success();

From 3bed2bb95a4717fb31a2ecdba7c911fbbfbf3e0d Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Fri, 8 Dec 2023 17:55:25 +0000
Subject: [PATCH 6/8] Reference commit for inner tiles transposition

---
 .../FoldIntoPackAndUnpackPatterns.cpp         | 58 +++++++++++++++++--
 1 file changed, 53 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index ee75205cba6cf..7be771f820a53 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -97,21 +97,69 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
+    auto packInnerDimsPos = packOp.getInnerDimsPos();
+    auto packInnerTiles = packOp.getStaticInnerTiles();
     auto packOuterDimsPerm = packOp.getOuterDimsPerm();
     auto transposePerm = transposeOp.getPermutation();
     SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackInnerDimsPosVec;
+    SmallVector<int64_t> newPackInnerTilesVec;
+
+    // Variable for storing translated position after considering original
+    // outer_dims_perm and permutation attributes of tensor.pack and
+    // linalg.transpose.
+    int64_t translatedPosition;
+
+    // Process transpose operation for non-tiled outer dimensions of the tensor.
+    for (unsigned int i = 0; i < transposePerm.size() - packInnerTiles.size();
+         ++i) {
+      // If tensor.pack has outer_dims_perm attribute, then consider it during
+      // index translation.
+      if (packOuterDimsPerm.size())
+        translatedPosition = packOuterDimsPerm[transposePerm[i]];
+      else
+        translatedPosition = transposePerm[i];
+
+      // Cannot fold in pack if a tile dimension was transposed with a non-tile
+      // dimension.
+      if (translatedPosition >= transposePerm.size() - packInnerTiles.size())
+        return failure();
 
-    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
-      newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+      newPackOuterDimsPermVec.push_back(translatedPosition);
+    }
+
+    // Process transpose operation for tiled inner dimensions of the tensor.
+    for (unsigned int i = transposePerm.size() - packInnerTiles.size();
+         i < transposePerm.size(); ++i) {
+      translatedPosition =
+          transposePerm[i] - (transposePerm.size() - packInnerTiles.size());
+
+      newPackInnerTilesVec.push_back(packInnerTiles[translatedPosition]);
+      newPackInnerDimsPosVec.push_back(packInnerDimsPos[translatedPosition]);
+    }
+
+    llvm::SmallVector<OpFoldResult, 4> opFoldResultsTiles;
+    opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
+
+    llvm::transform(
+        newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
+        [&rewriter](int64_t value) {
+          return IntegerAttr::get(IndexType::get(rewriter.getContext()), value);
+        });
+
+    llvm::ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
 
     Value output = packOp.createDestinationTensor(
         rewriter, transposeOp.getLoc(), packOp.getSource(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        newPackInnerTilesArrayRef,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
         static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
-        packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
+        transposeOp, packOp.getSource(), output,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+        newPackInnerTilesArrayRef,
+        /*paddingValue=*/std::nullopt,
         static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     return success();

From 1fd838bf1f02c9c4e058ed20021695c08e82b9d2 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Fri, 8 Dec 2023 18:37:58 +0000
Subject: [PATCH 7/8] Reference commit for transpose_pack_ordering pattern

---
 .../FoldIntoPackAndUnpackPatterns.cpp         | 104 ++++++++++++++----
 .../Tensor/fold-into-pack-and-unpack.mlir     |   2 +
 2 files changed, 82 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 7be771f820a53..d3186a0873434 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -120,10 +120,14 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
       else
         translatedPosition = transposePerm[i];
 
-      // Cannot fold in pack if a tile dimension was transposed with a non-tile
-      // dimension.
-      if (translatedPosition >= transposePerm.size() - packInnerTiles.size())
-        return failure();
+      // Note: static_cast was added around translatedPosition to suppress the
+      // compiler warning of comparison between variables of different types.
+      if (static_cast<unsigned long>(translatedPosition) >=
+          transposePerm.size() - packInnerTiles.size())
+        return rewriter.notifyMatchFailure(
+            transposeOp,
+            "Cannot fold in tensor.pack if a tile dimension was transposed "
+            "with a non-tile dimension in linalg.transpose.");
 
       newPackOuterDimsPermVec.push_back(translatedPosition);
     }
@@ -138,29 +142,28 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
       newPackInnerDimsPosVec.push_back(packInnerDimsPos[translatedPosition]);
     }
 
-    llvm::SmallVector<OpFoldResult, 4> opFoldResultsTiles;
+    SmallVector<OpFoldResult> opFoldResultsTiles;
     opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
 
-    llvm::transform(
-        newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
-        [&rewriter](int64_t value) {
-          return IntegerAttr::get(IndexType::get(rewriter.getContext()), value);
-        });
+    transform(newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
+              [&rewriter](int64_t value) {
+                return IntegerAttr::get(IndexType::get(rewriter.getContext()),
+                                        value);
+              });
 
-    llvm::ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
+    ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
 
     Value output = packOp.createDestinationTensor(
         rewriter, transposeOp.getLoc(), packOp.getSource(),
         newPackInnerTilesArrayRef,
-        static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
-        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     rewriter.replaceOpWithNewOp<PackOp>(
         transposeOp, packOp.getSource(), output,
-        static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
-        newPackInnerTilesArrayRef,
-        /*paddingValue=*/std::nullopt,
-        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+        newPackInnerTilesArrayRef, packOp.getPaddingValue(),
+        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     return success();
   }
@@ -180,22 +183,75 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     if (!transposeOp)
       return failure();
 
+    auto packInnerDimsPos = packOp.getInnerDimsPos();
+    auto packInnerTiles = packOp.getStaticInnerTiles();
     auto packOuterDimsPerm = packOp.getOuterDimsPerm();
     auto transposePerm = transposeOp.getPermutation();
     SmallVector<int64_t> newPackOuterDimsPermVec;
+    SmallVector<int64_t> newPackInnerDimsPosVec;
+    SmallVector<int64_t> newPackInnerTilesVec;
+
+    // Variable for storing translated position after considering original
+    // outer_dims_perm and permutation attributes of tensor.pack and
+    // linalg.transpose.
+    int64_t translatedPosition;
 
-    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
-      newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+    // Process transpose operation for non-tiled outer dimensions of the tensor.
+    for (unsigned int i = 0; i < transposePerm.size() - packInnerTiles.size();
+         ++i) {
+      // If tensor.pack has outer_dims_perm attribute, then consider it during
+      // index translation.
+      if (packOuterDimsPerm.size())
+        translatedPosition = packOuterDimsPerm[transposePerm[i]];
+      else
+        translatedPosition = transposePerm[i];
+
+      // Note: static_cast was added around translatedPosition to suppress the
+      // compiler warning of comparison between variables of different types.
+      if (static_cast<unsigned long>(translatedPosition) >=
+          transposePerm.size() - packInnerTiles.size())
+        return rewriter.notifyMatchFailure(
+            packOp,
+            "Cannot fold in tensor.pack if a tile dimension was transposed "
+            "with a non-tile dimension in linalg.transpose.");
+
+      newPackOuterDimsPermVec.push_back(translatedPosition);
+    }
+
+    // Process transpose operation for tiled inner dimensions of the tensor.
+    for (unsigned int i = transposePerm.size() - packInnerTiles.size();
+         i < transposePerm.size(); ++i) {
+      translatedPosition =
+          transposePerm[i] - (transposePerm.size() - packInnerTiles.size());
+
+      newPackInnerTilesVec.push_back(packInnerTiles[translatedPosition]);
+      newPackInnerDimsPosVec.push_back(packInnerDimsPos[translatedPosition]);
+    }
+
+    SmallVector<OpFoldResult> opFoldResultsTiles;
+    opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
+
+    transform(newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
+              [&rewriter](int64_t value) {
+                return IntegerAttr::get(IndexType::get(rewriter.getContext()),
+                                        value);
+              });
+
+    ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
 
     Value output = packOp.createDestinationTensor(
         rewriter, packOp.getLoc(), transposeOp.getOperand(0),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+        newPackInnerTilesArrayRef, 
+        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+    output.dump();
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
-        packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
-        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+        packOp, transposeOp.getOperand(0), output, 
+        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+        newPackInnerTilesArrayRef, packOp.getPaddingValue(),
+        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
 
     return success();
   }
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 0b00c7fa7feb9..2b10857172ab2 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -115,6 +115,8 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
 // CHECK:         tensor.pad
 // CHECK:         tensor.pack
 
+// -----
+
 func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<1x56x57x64xf32>
   %transposed = linalg.transpose

>From 5917857e1e7c1511c5108c30fa6e821115c84094 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Fri, 8 Dec 2023 20:27:20 +0000
Subject: [PATCH 8/8] Drop consumer pack producer transpose case and add more
 tests

---
 .../FoldIntoPackAndUnpackPatterns.cpp         | 115 ++----------------
 .../Tensor/fold-into-pack-and-unpack.mlir     |  83 ++++++++++---
 2 files changed, 78 insertions(+), 120 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index d3186a0873434..58d816aa8d83e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -115,20 +115,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
          ++i) {
       // If tensor.pack has outer_dims_perm attribute, then consider it during
       // index translation.
-      if (packOuterDimsPerm.size())
-        translatedPosition = packOuterDimsPerm[transposePerm[i]];
-      else
+      if (packOuterDimsPerm.size()) {
+        // Note: static_cast is added around transposePerm[i] to suppress the
+        // compiler warning of comparison between variables of different types.
+        if (static_cast<unsigned long>(transposePerm[i]) <
+            packOuterDimsPerm.size())
+          translatedPosition = packOuterDimsPerm[transposePerm[i]];
+        else
+          return rewriter.notifyMatchFailure(
+              transposeOp,
+              "Cannot fold in tensor.pack if a tile dimension was transposed "
+              "with a non-tile dimension in linalg.transpose.");
+      } else
         translatedPosition = transposePerm[i];
 
-      // Note: static_cast was added around translatedPosition to suppress the
-      // compiler warning of comparison between variables of different types.
-      if (static_cast<unsigned long>(translatedPosition) >=
-          transposePerm.size() - packInnerTiles.size())
-        return rewriter.notifyMatchFailure(
-            transposeOp,
-            "Cannot fold in tensor.pack if a tile dimension was transposed "
-            "with a non-tile dimension in linalg.transpose.");
-
       newPackOuterDimsPermVec.push_back(translatedPosition);
     }
 
@@ -168,100 +168,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     return success();
   }
 };
-
-/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
-/// semantics.
-struct FoldConsumerPackWithProducerLinalgTransposeOp
-    : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    auto packInputTensor = packOp.getOperand(0);
-    auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
-
-    if (!transposeOp)
-      return failure();
-
-    auto packInnerDimsPos = packOp.getInnerDimsPos();
-    auto packInnerTiles = packOp.getStaticInnerTiles();
-    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
-    SmallVector<int64_t> newPackOuterDimsPermVec;
-    SmallVector<int64_t> newPackInnerDimsPosVec;
-    SmallVector<int64_t> newPackInnerTilesVec;
-
-    // Variable for storing translated position after considering original
-    // outer_dims_perm and permutation attributes of tensor.pack and
-    // linalg.transpose.
-    int64_t translatedPosition;
-
-    // Process transpose operation for non-tiled outer dimensions of the tensor.
-    for (unsigned int i = 0; i < transposePerm.size() - packInnerTiles.size();
-         ++i) {
-      // If tensor.pack has outer_dims_perm attribute, then consider it during
-      // index translation.
-      if (packOuterDimsPerm.size())
-        translatedPosition = packOuterDimsPerm[transposePerm[i]];
-      else
-        translatedPosition = transposePerm[i];
-
-      // Note: static_cast was added around translatedPosition to suppress the
-      // compiler warning of comparison between variables of different types.
-      if (static_cast<unsigned long>(translatedPosition) >=
-          transposePerm.size() - packInnerTiles.size())
-        return rewriter.notifyMatchFailure(
-            packOp,
-            "Cannot fold in tensor.pack if a tile dimension was transposed "
-            "with a non-tile dimension in linalg.transpose.");
-
-      newPackOuterDimsPermVec.push_back(translatedPosition);
-    }
-
-    // Process transpose operation for tiled inner dimensions of the tensor.
-    for (unsigned int i = transposePerm.size() - packInnerTiles.size();
-         i < transposePerm.size(); ++i) {
-      translatedPosition =
-          transposePerm[i] - (transposePerm.size() - packInnerTiles.size());
-
-      newPackInnerTilesVec.push_back(packInnerTiles[translatedPosition]);
-      newPackInnerDimsPosVec.push_back(packInnerDimsPos[translatedPosition]);
-    }
-
-    SmallVector<OpFoldResult> opFoldResultsTiles;
-    opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
-
-    transform(newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
-              [&rewriter](int64_t value) {
-                return IntegerAttr::get(IndexType::get(rewriter.getContext()),
-                                        value);
-              });
-
-    ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
-
-    Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
-        newPackInnerTilesArrayRef, 
-        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
-        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
-
-    output.dump();
-
-    rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, 
-        static_cast<ArrayRef<int64_t>>(newPackInnerDimsPosVec),
-        newPackInnerTilesArrayRef, packOp.getPaddingValue(),
-        static_cast<ArrayRef<int64_t>>(newPackOuterDimsPermVec));
-
-    return success();
-  }
-};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
-                  FoldProducerPackWithConsumerLinalgTransposeOp,
-                  FoldConsumerPackWithProducerLinalgTransposeOp>(
+                  FoldProducerPackWithConsumerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 2b10857172ab2..70c78aa6cda1c 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -117,36 +117,35 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
 
 // -----
 
-func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
-  %0 = tensor.empty() : tensor<1x56x57x64xf32>
-  %transposed = linalg.transpose
-    ins(%arg0 : tensor<56x57x1x64xf32>)
-    outs(%0 : tensor<1x56x57x64xf32>)
-    permutation = [2, 0, 1, 3]
-
-  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
-  %pack = tensor.pack %transposed
-    outer_dims_perm = [0, 3, 1, 2]
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<56x2x1x57x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 3, 2, 1]
     inner_dims_pos = [3]
     inner_tiles = [32]
-    into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
-  return %pack : tensor<1x2x56x57x32xf32>
+    into %0 : tensor<56x57x1x64xf32> -> tensor<56x2x1x57x32xf32>
+
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %transposed = linalg.transpose
+    ins(%pack : tensor<56x2x1x57x32xf32>)
+    outs(%1 : tensor<1x57x56x2x32xf32>)
+    permutation = [2, 3, 0, 1, 4]
+  return %transposed : tensor<1x57x56x2x32xf32>
 }
-//      CHECK: func @linalg_transpose_tensor_pack_fold(
+//      CHECK: func @tensor_pack_linalg_transpose_fold(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
-//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
 //      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
 // CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[PACK]]
 
 // -----
 
-func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+func.func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
   %pack = tensor.pack %arg0
-    outer_dims_perm = [0, 1, 2, 3]
     inner_dims_pos = [3]
     inner_tiles = [32]
     into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
@@ -158,7 +157,7 @@ func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> t
     permutation = [2, 3, 0, 1, 4]
   return %transposed : tensor<1x2x56x57x32xf32>
 }
-//      CHECK: func @tensor_pack_linalg_transpose_fold(
+//      CHECK: func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
 //      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
 //      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
@@ -166,3 +165,51 @@ func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> t
 // CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(%arg0: tensor<56x64x4x64xf32>) -> tensor<2x2x56x2x32x32x2xf32> {
+  %0 = tensor.empty() : tensor<56x2x2x2x32x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [1, 2, 3]
+    inner_tiles = [32, 2, 32]
+    into %0 : tensor<56x64x4x64xf32> -> tensor<56x2x2x2x32x2x32xf32>
+
+  %1 = tensor.empty() : tensor<2x2x56x2x32x32x2xf32>
+  %transposed = linalg.transpose
+    ins(%pack : tensor<56x2x2x2x32x2x32xf32>)
+    outs(%1 : tensor<2x2x56x2x32x32x2xf32>)
+    permutation = [2, 3, 0, 1, 6, 4, 5]
+  return %transposed : tensor<2x2x56x2x32x32x2xf32>
+}
+//      CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x64x4x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<2x2x56x2x32x32x2xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3, 1, 2] inner_tiles = [32, 32, 2] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(%arg0: tensor<56x64x4x64xf32>) -> tensor<2x2x2x2x32x32x56xf32> {
+  %0 = tensor.empty() : tensor<56x2x2x2x32x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [1, 2, 3]
+    inner_tiles = [32, 2, 32]
+    into %0 : tensor<56x64x4x64xf32> -> tensor<56x2x2x2x32x2x32xf32>
+
+  %1 = tensor.empty() : tensor<2x2x2x2x32x32x56xf32>
+  %transposed = linalg.transpose
+    ins(%pack : tensor<56x2x2x2x32x2x32xf32>)
+    outs(%1 : tensor<2x2x2x2x32x32x56xf32>)
+    permutation = [2, 3, 5, 1, 6, 4, 0]
+  return %transposed : tensor<2x2x2x2x32x32x56xf32>
+}
+//      CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x64x4x64xf32>)
+//      CHECK:   tensor.pack
+//      CHECK:   linalg.transpose



More information about the llvm-commits mailing list