[Mlir-commits] [mlir] 78348b6 - [mlir][tensor] Improve tensor.pack simplification pattern. (#76606)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 2 09:34:28 PST 2024


Author: Han-Chung Wang
Date: 2024-01-02T09:34:24-08:00
New Revision: 78348b691504bf9ec212add73cc37d2fd8371f83

URL: https://github.com/llvm/llvm-project/commit/78348b691504bf9ec212add73cc37d2fd8371f83
DIFF: https://github.com/llvm/llvm-project/commit/78348b691504bf9ec212add73cc37d2fd8371f83.diff

LOG: [mlir][tensor] Improve tensor.pack simplification pattern.  (#76606)

A tensor.pack op can be rewritten to a tensor.expand_shape op if the
packing only happens on the innermost dimension, and the op has no
padding value and no outer_dims_perm.

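This also improves the formatting of the lit checks: the first line of each
test now uses CHECK-LABEL, and the check lines are consistently indented.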

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
    mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 67651a2e38c82d..e20450c95ffd5f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    if (!packOp.getOuterDimsPerm().empty())
+      return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+
     RankedTensorType sourceType = packOp.getSourceType();
     RankedTensorType destType = packOp.getDestType();
-    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
-      return failure();
+    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects packing at the innermost dimension");
+    }
+
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)

diff  --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 049076a67bae53..bdfe18acd86c53 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
 
-// CHECK: func.func @single_dim_packing(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
-// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+// CHECK-LABEL: func.func @single_dim_packing(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<8x32xf32>
 func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
 
 // -----
 
-// CHECK: func.func @single_dim_packing_with_padding(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
-// CHECK-NOT: tensor.expand_shape
-// CHECK: tensor.pack
+// CHECK-LABEL: func.func @single_dim_packing_with_padding(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
 func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
   return %0 : tensor<8x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_packing(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
+func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
+  %empty = tensor.empty() : tensor<5x8x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
+  return %0 : tensor<5x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
+func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
+  %empty = tensor.empty() : tensor<8x5x32xf32>
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
+  return %0 : tensor<8x5x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_packing(
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
+func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
+  %empty = tensor.empty() : tensor<8x5x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
+  return %0 : tensor<8x5x32xf32>
+}


        


More information about the Mlir-commits mailing list