[mlir] [llvm] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)
Prathamesh Tagore via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 8 09:55:42 PST 2023
https://github.com/meshtag updated https://github.com/llvm/llvm-project/pull/74206
>From febd3de154543df1a4828994bfe716fee5cf4e52 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Thu, 30 Nov 2023 21:44:17 +0000
Subject: [PATCH 1/6] First working draft for both fold directions (transpose -> pack and pack -> transpose)
---
.../FoldIntoPackAndUnpackPatterns.cpp | 81 +++++++++++-
.../Tensor/fold-into-pack-and-unpack.mlir | 121 +++++-------------
result.mlir | 0
3 files changed, 110 insertions(+), 92 deletions(-)
create mode 100644 result.mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef910..ca69341afb51c8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,11 +6,14 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include <iostream>
+
namespace mlir {
namespace tensor {
namespace {
@@ -81,10 +84,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+///
+/// Matches a linalg.transpose whose operand 0 is produced by a tensor.pack and
+/// replaces the transpose with a new tensor.pack of the original pack's
+/// source. The transpose is expressed through the new outer_dims_perm, whose
+/// entry i is packOuterDimsPerm[transposePerm[i]]; inner_dims_pos and the
+/// inner tiles are reused unchanged, and the destination is a fresh
+/// tensor.empty with the transpose's result shape.
+///
+/// NOTE(review): gaps in this draft that should be addressed before landing:
+///  - The composition loop is bounded by packOuterDimsPerm.size(). When the
+///    pack carries no outer_dims_perm (presumably an empty ArrayRef --
+///    confirm), no entries are produced and the transpose is silently
+///    dropped; the identity permutation should be assumed instead.
+///  - Transpose entries at positions >= packOuterDimsPerm.size() (the tile
+///    dimensions) are never inspected, so a transpose that permutes the tile
+///    dimensions is dropped. Also, transposePerm[i] indexes packOuterDimsPerm
+///    without a bounds check, so a transpose that moves a tile dimension into
+///    an outer position can read past the end of packOuterDimsPerm.
+///  - rankedTensorType comes from dyn_cast and is not null-checked before
+///    getShape() is called on it.
+///  - paddingValue is passed as std::nullopt, so a padding_value on the
+///    original pack is not carried over to the replacement.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeInputTensor = transposeOp.getOperand(0);
+ auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+ if (!packOp)
+ return failure();
+
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+ for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+ newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+
+ // Create a new empty output tensor.
+ Type elementType = packOp.getDestType().getElementType();
+ auto transposeOpResultType = transposeOp.getResult().getType()[0];
+ auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
+ Value output = rewriter.create<EmptyOp>(
+ transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), std::nullopt,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+ return success();
+ }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+///
+/// Matches a tensor.pack whose operand 0 is produced by a linalg.transpose and
+/// replaces the pack with a new tensor.pack that reads the transpose's operand
+/// 0 directly. The transpose is expressed through the new outer_dims_perm,
+/// whose entry i is transposePerm[packOuterDimsPerm[i]]; the destination is a
+/// fresh tensor.empty with the original pack's result shape.
+///
+/// NOTE(review): gaps in this draft that should be addressed before landing:
+///  - inner_dims_pos is reused unchanged although the new pack's source is the
+///    untransposed tensor. inner_dims_pos indexes source dimensions, so each
+///    entry presumably needs remapping to transposePerm[pos]; a test whose
+///    inner_dims_pos entries are fixed points of the permutation cannot
+///    detect this.
+///  - The composition loop is bounded by packOuterDimsPerm.size(). When the
+///    pack carries no outer_dims_perm (presumably an empty ArrayRef --
+///    confirm), no entries are produced and the transpose is silently
+///    dropped; the new outer_dims_perm should then be transposePerm itself.
+///  - paddingValue is passed as std::nullopt, so a padding_value on the
+///    original pack is not carried over to the replacement.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+ : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto packInputTensor = packOp.getOperand(0);
+ auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+ if (!transposeOp)
+ return failure();
+
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+ for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+ newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+
+ // Create a new empty output tensor.
+ Type elementType = packOp.getDestType().getElementType();
+ auto packOpResultType = packOp.getResult().getType();
+ auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
+ Value output = rewriter.create<EmptyOp>(
+ packOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), std::nullopt,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+ return success();
+ }
+};
} // namespace
+/// Adds the fold-into-pack/unpack rewrites to `patterns`: the existing
+/// unpack/extract_slice and pad/pack folds plus the two new
+/// linalg.transpose <-> tensor.pack folds.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
- patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+ patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+ FoldProducerPackWithConsumerLinalgTransposeOp,
+ FoldConsumerPackWithProducerLinalgTransposeOp>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c757896657427..bc67b754803837 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,100 +17,39 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
-// -----
-
-func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
- %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
- : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
- %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
-// CHECK: %[[UNPACK:.+]] = tensor.unpack
-// CHECK: tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
- %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
- : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
- %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
+func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+ %0 = tensor.empty() : tensor<1x56x57x64xf32>
+ %transposed = linalg.transpose
+ ins(%arg0 : tensor<56x57x1x64xf32>)
+ outs(%0 : tensor<1x56x57x64xf32>)
+ permutation = [2, 0, 1, 3]
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+
+ // [2, 3, 0, 1]
+
+ %pack = tensor.pack %transposed
+ outer_dims_perm = [0, 3, 1, 2]
+ inner_dims_pos = [3]
+ inner_tiles = [32]
+ into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+ return %pack : tensor<1x2x56x57x32xf32>
}
-// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
-// CHECK: %[[UNPACK:.+]] = tensor.unpack
-// CHECK: tensor.extract_slice %[[UNPACK]]
-// -----
-func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : index, %arg3 : index) -> tensor<f32> {
- %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
- : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
- %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
- return %1 : tensor<f32>
-}
-// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
-// CHECK: %[[UNPACK:.+]] = tensor.unpack
-// CHECK: tensor.extract_slice %[[UNPACK]]
-
-// -----
-
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
- ^bb0(%arg0: index, %arg1: index):
- tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
- %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @pad_pack
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]]
-// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
-
-// -----
-
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
- ^bb0(%arg0: index, %arg1: index):
- tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
- %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
-}
-// CHECK-LABEL: func.func @nofold_pad_pack
-// CHECK: tensor.pad
-// CHECK: tensor.pack
+func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+ %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+ %pack = tensor.pack %arg0
+ outer_dims_perm = [0, 1, 2, 3]
+ inner_dims_pos = [3]
+ inner_tiles = [32]
+ into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
-// -----
+ // [2, 3, 0, 1]
-func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
- %cst0 = arith.constant 0.000000e+00 : f32
- %cst1 = arith.constant 1.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
- ^bb0(%arg0: index, %arg1: index):
- tensor.yield %cst0 : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
- %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %transposed = linalg.transpose
+ ins(%pack : tensor<56x57x1x2x32xf32>)
+ outs(%1 : tensor<1x2x56x57x32xf32>)
+ permutation = [2, 3, 0, 1, 4]
+ return %transposed : tensor<1x2x56x57x32xf32>
}
-// CHECK-LABEL: func.func @pad_pack_different_padding_value
-// CHECK: tensor.pad
-// CHECK: tensor.pack
diff --git a/result.mlir b/result.mlir
new file mode 100644
index 00000000000000..e69de29bb2d1d6
>From 57e33eb73e6ff6d498d0cba511419b4245ca3399 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 19:59:54 +0000
Subject: [PATCH 2/6] Add tests
---
.../Tensor/fold-into-pack-and-unpack.mlir | 125 +++++++++++++++++-
1 file changed, 118 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index bc67b754803837..b9b37680cdf477 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -17,7 +17,105 @@ func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
-func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+// -----
+
+// An extract_slice with a non-zero offset on an unpack result is not folded.
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+// An extract_slice with a dynamic (possibly non-unit) stride on an unpack
+// result is not folded.
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+// A rank-reducing extract_slice on an unpack result is not folded.
+func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index) -> tensor<f32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<?x?xf32> to tensor<f32>
+ return %1 : tensor<f32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_rank_reduced(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+// A tensor.pad feeding a tensor.pack with the same padding value folds into
+// the pack, which then reads the unpadded source directly.
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+// A tensor.pad marked 'nofold' is kept in front of the pack.
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK: tensor.pad
+// CHECK: tensor.pack
+
+// -----
+
+// The pad yields 0.0 but the pack pads with 1.0, so the two are not folded.
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst0 = arith.constant 0.000000e+00 : f32
+ %cst1 = arith.constant 1.000000e+00 : f32
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst0 : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK: tensor.pad
+// CHECK: tensor.pack
+
+// -----
+
+// Checks that a linalg.transpose feeding a tensor.pack folds into a single
+// tensor.pack whose outer_dims_perm composes both permutations.
+// NOTE(review): inner_dims_pos = [3] is a fixed point of the transpose
+// permutation [2, 0, 1, 3], so this case does not exercise remapping of
+// inner_dims_pos; a case where the transpose moves the tiled dimension would.
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
%0 = tensor.empty() : tensor<1x56x57x64xf32>
%transposed = linalg.transpose
ins(%arg0 : tensor<56x57x1x64xf32>)
@@ -25,8 +123,6 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
permutation = [2, 0, 1, 3]
%1 = tensor.empty() : tensor<1x2x56x57x32xf32>
- // [2, 3, 0, 1]
-
%pack = tensor.pack %transposed
outer_dims_perm = [0, 3, 1, 2]
inner_dims_pos = [3]
@@ -34,9 +130,18 @@ func.func @foo(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
return %pack : tensor<1x2x56x57x32xf32>
}
+// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[PACK]]
+// -----
-func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+// Checks that a tensor.pack followed by a linalg.transpose of its outer
+// dimensions folds into a single tensor.pack; the tile dimension (result
+// dimension 4) is left in place by the transpose.
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
%0 = tensor.empty() : tensor<56x57x1x2x32xf32>
%pack = tensor.pack %arg0
outer_dims_perm = [0, 1, 2, 3]
@@ -44,12 +149,18 @@ func.func @foo1(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
inner_tiles = [32]
into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
- // [2, 3, 0, 1]
-
- %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
%transposed = linalg.transpose
ins(%pack : tensor<56x57x1x2x32xf32>)
outs(%1 : tensor<1x2x56x57x32xf32>)
permutation = [2, 3, 0, 1, 4]
return %transposed : tensor<1x2x56x57x32xf32>
}
+// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[PACK]]
>From 354a61dcdc0226b3d1aaba263e7f2b41b325e96f Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:02:12 +0000
Subject: [PATCH 3/6] Remove debug <iostream> include and empty result.mlir
---
.../Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp | 2 --
result.mlir | 0
2 files changed, 2 deletions(-)
delete mode 100644 result.mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index ca69341afb51c8..47d85a6f4f9a59 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -12,8 +12,6 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
-#include <iostream>
-
namespace mlir {
namespace tensor {
namespace {
diff --git a/result.mlir b/result.mlir
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
>From 1b9d96d0a111d863a7cc6e66837e55a11f8300bd Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 2 Dec 2023 20:09:56 +0000
Subject: [PATCH 4/6] Rectify formatting
---
mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index b9b37680cdf477..0b00c7fa7feb99 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -121,8 +121,8 @@ func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
ins(%arg0 : tensor<56x57x1x64xf32>)
outs(%0 : tensor<1x56x57x64xf32>)
permutation = [2, 0, 1, 3]
- %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
%pack = tensor.pack %transposed
outer_dims_perm = [0, 3, 1, 2]
inner_dims_pos = [3]
@@ -149,8 +149,8 @@ func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> t
inner_tiles = [32]
into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
- %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
- %transposed = linalg.transpose
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %transposed = linalg.transpose
ins(%pack : tensor<56x57x1x2x32xf32>)
outs(%1 : tensor<1x2x56x57x32xf32>)
permutation = [2, 3, 0, 1, 4]
>From 840ba3ee5063aa7b4633e681d412786ee37f51b4 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Mon, 4 Dec 2023 14:44:50 +0000
Subject: [PATCH 5/6] Apply first batch of suggested changes
---
.../FoldIntoPackAndUnpackPatterns.cpp | 28 ++++++++-----------
1 file changed, 12 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 47d85a6f4f9a59..ee75205cba6cf3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -99,21 +99,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
auto packOuterDimsPerm = packOp.getOuterDimsPerm();
auto transposePerm = transposeOp.getPermutation();
- llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+ SmallVector<int64_t> newPackOuterDimsPermVec;
for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
- // Create a new empty output tensor.
- Type elementType = packOp.getDestType().getElementType();
- auto transposeOpResultType = transposeOp.getResult().getType()[0];
- auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
- Value output = rewriter.create<EmptyOp>(
- transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+ Value output = packOp.createDestinationTensor(
+ rewriter, transposeOp.getLoc(), packOp.getSource(),
+ packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
rewriter.replaceOpWithNewOp<PackOp>(
transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
- packOp.getMixedTiles(), std::nullopt,
+ packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
return success();
@@ -136,21 +134,19 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
auto packOuterDimsPerm = packOp.getOuterDimsPerm();
auto transposePerm = transposeOp.getPermutation();
- llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+ SmallVector<int64_t> newPackOuterDimsPermVec;
for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
- // Create a new empty output tensor.
- Type elementType = packOp.getDestType().getElementType();
- auto packOpResultType = packOp.getResult().getType();
- auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
- Value output = rewriter.create<EmptyOp>(
- packOp.getLoc(), rankedTensorType.getShape(), elementType);
+ Value output = packOp.createDestinationTensor(
+ rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+ packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
rewriter.replaceOpWithNewOp<PackOp>(
packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
- packOp.getMixedTiles(), std::nullopt,
+ packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
return success();
>From 3bed2bb95a4717fb31a2ecdba7c911fbbfbf3e0d Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Fri, 8 Dec 2023 17:55:25 +0000
Subject: [PATCH 6/6] Reference commit for inner tiles transposition
---
.../FoldIntoPackAndUnpackPatterns.cpp | 58 +++++++++++++++++--
1 file changed, 53 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index ee75205cba6cf3..7be771f820a531 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -97,21 +97,69 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();
+ auto packInnerDimsPos = packOp.getInnerDimsPos();
+ auto packInnerTiles = packOp.getStaticInnerTiles();
auto packOuterDimsPerm = packOp.getOuterDimsPerm();
auto transposePerm = transposeOp.getPermutation();
SmallVector<int64_t> newPackOuterDimsPermVec;
+ SmallVector<int64_t> newPackInnerDimsPosVec;
+ SmallVector<int64_t> newPackInnerTilesVec;
+
+ // Variable for storing translated position after considering original
+ // outer_dims_perm and permutation attributes of tensor.pack and
+ // linalg.transpose.
+ int64_t translatedPosition;
+
+ // Process transpose operation for non-tiled outer dimensions of the tensor.
+ for (unsigned int i = 0; i < transposePerm.size() - packInnerTiles.size();
+ ++i) {
+ // If tensor.pack has outer_dims_perm attribute, then consider it during
+ // index translation.
+ if (packOuterDimsPerm.size())
+ translatedPosition = packOuterDimsPerm[transposePerm[i]];
+ else
+ translatedPosition = transposePerm[i];
+
+ // Cannot fold in pack if a tile dimension was transposed with a non-tile
+ // dimension.
+ if (translatedPosition >= transposePerm.size() - packInnerTiles.size())
+ return failure();
- for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
- newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+ newPackOuterDimsPermVec.push_back(translatedPosition);
+ }
+
+ // Process transpose operation for tiled inner dimensions of the tensor.
+ for (unsigned int i = transposePerm.size() - packInnerTiles.size();
+ i < transposePerm.size(); ++i) {
+ translatedPosition =
+ transposePerm[i] - (transposePerm.size() - packInnerTiles.size());
+
+ newPackInnerTilesVec.push_back(packInnerTiles[translatedPosition]);
+ newPackInnerDimsPosVec.push_back(packInnerDimsPos[translatedPosition]);
+ }
+
+ llvm::SmallVector<OpFoldResult, 4> opFoldResultsTiles;
+ opFoldResultsTiles.reserve(newPackInnerTilesVec.size());
+
+ llvm::transform(
+ newPackInnerTilesVec, std::back_inserter(opFoldResultsTiles),
+ [&rewriter](int64_t value) {
+ return IntegerAttr::get(IndexType::get(rewriter.getContext()), value);
+ });
+
+ llvm::ArrayRef<OpFoldResult> newPackInnerTilesArrayRef(opFoldResultsTiles);
Value output = packOp.createDestinationTensor(
rewriter, transposeOp.getLoc(), packOp.getSource(),
- packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+ newPackInnerTilesArrayRef,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
rewriter.replaceOpWithNewOp<PackOp>(
- transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
- packOp.getMixedTiles(), /*paddingValue=*/std::nullopt,
+ transposeOp, packOp.getSource(), output,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackInnerDimsPosVec),
+ newPackInnerTilesArrayRef,
+ /*paddingValue=*/std::nullopt,
static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
return success();
More information about the llvm-commits
mailing list