[Mlir-commits] [mlir] [mlir][tensor] Improve tensor.pack simplification pattern. (PR #76606)
Han-Chung Wang
llvmlistbot at llvm.org
Sat Dec 30 11:42:39 PST 2023
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/76606
>From fb3499336d0867f9d79b3c86a0e10d194afa9a7b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Sat, 30 Dec 2023 06:26:58 +0000
Subject: [PATCH] [mlir][tensor] Improve tensor.pack simplification pattern.
We can rewrite the op to tensor.expand_shape if the packing only happens
on the innermost dimension.
---
.../Transforms/PackAndUnpackPatterns.cpp | 14 +++++-
.../Dialect/Tensor/simplify-pack-unpack.mlir | 50 ++++++++++++++++---
2 files changed, 54 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 67651a2e38c82d..e20450c95ffd5f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ if (!packOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+
RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
- if (sourceType.getRank() != 1 || packOp.getPaddingValue())
- return failure();
+ ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ packOp, "expects packing at the innermost dimension");
+ }
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 049076a67bae53..bdfe18acd86c53 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
-// CHECK: func.func @single_dim_packing(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
-// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+// CHECK-LABEL: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
// -----
-// CHECK: func.func @single_dim_packing_with_padding(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
-// CHECK-NOT: tensor.expand_shape
-// CHECK: tensor.pack
+// CHECK-LABEL: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
return %0 : tensor<8x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
+func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
+ %empty = tensor.empty() : tensor<5x8x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
+ return %0 : tensor<5x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_packing(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
More information about the Mlir-commits
mailing list