[Mlir-commits] [mlir] [llvm] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)

Prathamesh Tagore llvmlistbot at llvm.org
Sat Dec 2 12:13:06 PST 2023


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/74206

Partial fix to https://github.com/openxla/iree/issues/15367

>From febd3de154543df1a4828994bfe716fee5cf4e52 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Thu, 30 Nov 2023 21:44:17 +0000
Subject: [PATCH 1/4] First draft: fold linalg.transpose into tensor.pack (producer and consumer cases)

---
 .../FoldIntoPackAndUnpackPatterns.cpp         |  81 +++++++++++-
 .../Tensor/fold-into-pack-and-unpack.mlir     | 121 +++++-------------
 result.mlir                                   |   0
 3 files changed, 110 insertions(+), 92 deletions(-)
 create mode 100644 result.mlir

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef91..ca69341afb51c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,11 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
+#include <iostream>
+
 namespace mlir {
 namespace tensor {
 namespace {
@@ -81,10 +84,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics. Only applies when the transpose permutes the outer dims and
+/// leaves the trailing inner tile dims in place.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getInput().getDefiningOp<PackOp>();
+    if (!packOp || transposeOp->getNumResults() != 1)
+      return failure();
+
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+    ArrayRef<int64_t> packPerm = packOp.getOuterDimsPerm();
+    int64_t outerRank = packOp.getSourceRank();
+    // Moving an inner tile dim cannot be expressed via 'outer_dims_perm'.
+    for (int64_t i = outerRank, e = transposePerm.size(); i < e; ++i)
+      if (transposePerm[i] != i)
+        return rewriter.notifyMatchFailure(transposeOp, "permutes tile dims");
+
+    // Result outer dim `i` is pack outer dim `transposePerm[i]`, i.e. source
+    // dim `packPerm[transposePerm[i]]`; an empty 'packPerm' means identity.
+    SmallVector<int64_t> newOuterDimsPerm;
+    for (int64_t i = 0; i < outerRank; ++i)
+      newOuterDimsPerm.push_back(
+          packPerm.empty() ? transposePerm[i] : packPerm[transposePerm[i]]);
+
+    // The transpose init already has the shape of the folded pack result.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), transposeOp.getInit(),
+        packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), newOuterDimsPerm);
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics. The transpose permutation is composed into 'outer_dims_perm'
+/// and 'inner_dims_pos' is remapped onto the dims of the transpose input;
+/// tiles and padding value are kept as-is.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    if (!transposeOp)
+      return failure();
+
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+    ArrayRef<int64_t> packPerm = packOp.getOuterDimsPerm();
+    // Pack outer dim `i` reads transposed dim `packPerm[i]` (identity when
+    // 'packPerm' is empty), i.e. input dim `transposePerm[packPerm[i]]`.
+    SmallVector<int64_t> newOuterDimsPerm;
+    for (int64_t i = 0, e = transposePerm.size(); i < e; ++i)
+      newOuterDimsPerm.push_back(
+          transposePerm[packPerm.empty() ? i : packPerm[i]]);
+
+    // 'inner_dims_pos' refers to dims of the transposed tensor; map each one
+    // back to the corresponding dim of the transpose input.
+    SmallVector<int64_t> newInnerDimsPos;
+    for (int64_t pos : packOp.getInnerDimsPos())
+      newInnerDimsPos.push_back(transposePerm[pos]);
+
+    // The packed result type is unchanged, so the pack destination is reused.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, transposeOp.getInput(), packOp.getDest(), newInnerDimsPos,
+        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+    return success();
+  }
+};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+  patterns.add<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+               FoldProducerPackWithConsumerLinalgTransposeOp,
+               FoldConsumerPackWithProducerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c75789665742..bc67b75480383 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,100 +17,39 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[UNPACK]]
 
-// -----
-
-func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
+func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+
+  // [2, 3, 0, 1]
+
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 3, 1, 2]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+  return %pack : tensor<1x2x56x57x32xf32>
 }
-// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
 
-// -----
 
-func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
-    %arg2 : index, %arg3 : index) -> tensor<f32> {
-  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
-      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
-  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
-  return %1 : tensor<f32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
-//       CHECK:   tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @pad_pack
-// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
-// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
-// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
-
-// -----
-
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @nofold_pad_pack
-// CHECK:         tensor.pad
-// CHECK:         tensor.pack
+func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-// -----
+    // [2, 3, 0, 1]
 
-func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
-  %cst0 = arith.constant 0.000000e+00 : f32
-  %cst1 = arith.constant 1.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
-  ^bb0(%arg0: index, %arg1: index):
-    tensor.yield %cst0 : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
-  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
+    %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+    %transposed = linalg.transpose
+    ins(%pack : tensor<56x57x1x2x32xf32>)
+    outs(%1 : tensor<1x2x56x57x32xf32>)
+    permutation = [2, 3, 0, 1, 4]
+  return %transposed : tensor<1x2x56x57x32xf32>
 }
-// CHECK-LABEL: func.func @pad_pack_different_padding_value
-// CHECK:         tensor.pad
-// CHECK:         tensor.pack
diff --git a/result.mlir b/result.mlir
new file mode 100644
index 0000000000000..e69de29bb2d1d

>From 57e33eb73e6ff6d498d0cba511419b4245ca3399 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 19:59:54 +0000
Subject: [PATCH 2/4] Add tests

---
 .../Tensor/fold-into-pack-and-unpack.mlir     | 125 +++++++++++++++++-
 1 file changed, 118 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index bc67b75480383..b9b37680cdf47 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,7 +17,105 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:       into %[[INIT]]
 //      CHECK:   return %[[UNPACK]]
 
-func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+// -----
+
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<f32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+
+// -----
+
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst0 = arith.constant 0.000000e+00 : f32
+  %cst1 = arith.constant 1.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst0 : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+// -----
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<1x56x57x64xf32>
   %transposed = linalg.transpose
     ins(%arg0 : tensor<56x57x1x64xf32>)
@@ -25,8 +123,6 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     permutation = [2, 0, 1, 3]
   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
 
-  // [2, 3, 0, 1]
-
   %pack = tensor.pack %transposed
     outer_dims_perm = [0, 3, 1, 2]
     inner_dims_pos = [3]
@@ -34,9 +130,18 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
   return %pack : tensor<1x2x56x57x32xf32>
 }
+// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
 
+// -----
 
-func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
   %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
   %pack = tensor.pack %arg0
     outer_dims_perm = [0, 1, 2, 3]
@@ -44,12 +149,18 @@ func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
     inner_tiles = [32]
     into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-    // [2, 3, 0, 1]
-
-    %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
     %transposed = linalg.transpose
     ins(%pack : tensor<56x57x1x2x32xf32>)
     outs(%1 : tensor<1x2x56x57x32xf32>)
     permutation = [2, 3, 0, 1, 4]
   return %transposed : tensor<1x2x56x57x32xf32>
 }
+// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]

>From 354a61dcdc0226b3d1aaba263e7f2b41b325e96f Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:02:12 +0000
Subject: [PATCH 3/4] Remove debug <iostream> include and empty result.mlir

---
 .../Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp | 2 --
 result.mlir                                                     | 0
 2 files changed, 2 deletions(-)
 delete mode 100644 result.mlir

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index ca69341afb51c..47d85a6f4f9a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -12,8 +12,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
-#include <iostream>
-
 namespace mlir {
 namespace tensor {
 namespace {
diff --git a/result.mlir b/result.mlir
deleted file mode 100644
index e69de29bb2d1d..0000000000000

>From 1b9d96d0a111d863a7cc6e66837e55a11f8300bd Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:09:56 +0000
Subject: [PATCH 4/4] Fix indentation in transpose/pack fold tests

---
 mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index b9b37680cdf47..0b00c7fa7feb9 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -121,8 +121,8 @@ func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
     ins(%arg0 : tensor<56x57x1x64xf32>)
     outs(%0 : tensor<1x56x57x64xf32>)
     permutation = [2, 0, 1, 3]
-  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
 
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
   %pack = tensor.pack %transposed
     outer_dims_perm = [0, 3, 1, 2]
     inner_dims_pos = [3]
@@ -149,8 +149,8 @@ func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> t
     inner_tiles = [32]
     into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
 
-   %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
-    %transposed = linalg.transpose
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %transposed = linalg.transpose
     ins(%pack : tensor<56x57x1x2x32xf32>)
     outs(%1 : tensor<1x2x56x57x32xf32>)
     permutation = [2, 3, 0, 1, 4]



More information about the Mlir-commits mailing list