[Mlir-commits] [mlir] [mlir][tensor] Add a pattern to simplify tensor.unpack to collapse shape (PR #76607)

Han-Chung Wang llvmlistbot at llvm.org
Sat Dec 30 00:05:00 PST 2023


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/76607

>From cf6fa4ed5c3cde4a0fac788932e8a5d570d36f62 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Sat, 30 Dec 2023 05:39:35 +0000
Subject: [PATCH 1/3] [mlir][tensor] Centralize pack/unpack related patterns.

This revision moves the pack/unpack related patterns to PackAndUnpackPatterns.cpp. This follows the convention used for other tensor ops.
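
For context, the SimplifyPackToExpandShape pattern being moved rewrites a rank-1 tensor.pack (with no padding value) into a tensor.expand_shape. A minimal before/after sketch, adapted from the existing simplify-tensor-pack.mlir test (function and value names are illustrative):

  // Before: a rank-1 pack with inner tile 32 and no padding value.
  func.func @example(%src: tensor<256xf32>) -> tensor<8x32xf32> {
    %empty = tensor.empty() : tensor<8x32xf32>
    %0 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [32]
        into %empty : tensor<256xf32> -> tensor<8x32xf32>
    return %0 : tensor<8x32xf32>
  }

  // After: the pack is a pure reshape; the now-dead tensor.empty is expected
  // to be dropped as dead code by the greedy rewrite driver.
  func.func @example(%src: tensor<256xf32>) -> tensor<8x32xf32> {
    %0 = tensor.expand_shape %src [[0, 1]]
        : tensor<256xf32> into tensor<8x32xf32>
    return %0 : tensor<8x32xf32>
  }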
---
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h  |  3 --
 .../Dialect/Tensor/Transforms/Transforms.h    |  5 +++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 38 -------------------
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |  2 +-
 ...Patterns.cpp => PackAndUnpackPatterns.cpp} | 34 +++++++++++++++++
 ...or-pack.mlir => simplify-pack-unpack.mlir} |  2 +-
 .../Dialect/Tensor/TestTensorTransforms.cpp   | 14 +++----
 7 files changed, 48 insertions(+), 50 deletions(-)
 rename mlir/lib/Dialect/Tensor/Transforms/{FoldIntoPackAndUnpackPatterns.cpp => PackAndUnpackPatterns.cpp} (80%)
 rename mlir/test/Dialect/Tensor/{simplify-tensor-pack.mlir => simplify-pack-unpack.mlir} (95%)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 06642adda42b38..0a21c9922b223b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -163,9 +163,6 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
-/// Patterns to simplify tensor.pack.
-void populateSimplifyTensorPack(RewritePatternSet &patterns);
-
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 44b8377bd6aad9..35b519e790d1c3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -74,6 +74,11 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 /// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that simplify `tensor.pack` and
+/// `tensor.unpack` operations.
+/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7c35dd4d953619..816e6ba8fed94e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3466,44 +3466,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-/// Packing one-dimensional tensor can be expressed as an expand shape op.
-struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
-
-  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
-                     Type newOperandType, ArrayAttr reassociation) const {
-    if (operand.getType() == newOperandType)
-      return operand;
-    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
-                                                  reassociation);
-  }
-
-  LogicalResult matchAndRewrite(PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    RankedTensorType sourceType = packOp.getSourceType();
-    RankedTensorType destType = packOp.getDestType();
-    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
-      return failure();
-    auto reassociation =
-        getReassociationIndicesForReshape(sourceType, destType);
-    if (!reassociation)
-      return failure();
-    Value expanded = insertExpand(
-        rewriter, packOp.getLoc(), packOp.getSource(), destType,
-        getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(packOp, expanded);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
-}
-
 template <typename OpTy>
 static LogicalResult
 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index d233ab7a0e8974..cbc0d499d9d52c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,10 +4,10 @@ add_mlir_dialect_library(MLIRTensorTransforms
   ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
-  FoldIntoPackAndUnpackPatterns.cpp
   FoldTensorSubsetOps.cpp
   IndependenceTransforms.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
+  PackAndUnpackPatterns.cpp
   ReshapePatterns.cpp
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
similarity index 80%
rename from mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
rename to mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e4509b331beeac..67651a2e38c82d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -21,6 +21,36 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
 }
 
+/// Packing one-dimensional tensor can be expressed as an expand shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+                     Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                                  reassociation);
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
+    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+      return failure();
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value expanded = insertExpand(
+        rewriter, packOp.getLoc(), packOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(packOp, expanded);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -150,5 +180,9 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
+  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
 } // namespace tensor
 } // namespace mlir
diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
similarity index 95%
rename from mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
rename to mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 75eb33ed033b9e..049076a67bae53 100644
--- a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
 
 // CHECK: func.func @single_dim_packing(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 3e142155df8d9b..b907f77e910825 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -84,9 +84,9 @@ struct TestTensorTransforms
           "the extract_slice of collapse_shape pattern"),
       llvm::cl::init(false)};
 
-  Option<bool> testSimplifyPackPatterns{
-      *this, "test-simplify-pack-patterns",
-      llvm::cl::desc("Test patterns to simplify tensor.pack"),
+  Option<bool> testSimplifyPackUnpackPatterns{
+      *this, "test-simplify-pack-unpack-patterns",
+      llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
       llvm::cl::init(false)};
 
   Option<bool> testTrackingListener{
@@ -137,9 +137,9 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
-static void applySimplifyPackPatterns(Operation *rootOp) {
+static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
-  tensor::populateSimplifyTensorPack(patterns);
+  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
@@ -376,8 +376,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
 
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
-  if (testSimplifyPackPatterns)
-    applySimplifyPackPatterns(rootOp);
+  if (testSimplifyPackUnpackPatterns)
+    applySimplifyPackUnpackPatterns(rootOp);
   if (testFoldConstantExtractSlice)
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)

>From 9cf7bb7d05f95601e746c431be6ea32a8bc5c37f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Sat, 30 Dec 2023 06:26:58 +0000
Subject: [PATCH 2/3] [mlir][tensor] Improve tensor.pack simplification pattern.

We can rewrite the op to tensor.expand_shape if the packing only happens
on the innermost dimension.
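
With this change the pattern also fires when the tiled dimension is the innermost dimension of a higher-rank source, provided there is no padding value and no outer_dims_perm. A before/after sketch, adapted from the new single_last_inner_dim_packing test below (names are illustrative):

  // Before: only the innermost dimension (dim 1) is tiled by 32.
  func.func @example(%src: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
    %empty = tensor.empty() : tensor<5x8x32xf32>
    %0 = tensor.pack %src inner_dims_pos = [1] inner_tiles = [32]
        into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
    return %0 : tensor<5x8x32xf32>
  }

  // After: the pack becomes a pure expand_shape with reassociation [[0], [1, 2]].
  func.func @example(%src: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
    %0 = tensor.expand_shape %src [[0], [1, 2]]
        : tensor<5x256xf32> into tensor<5x8x32xf32>
    return %0 : tensor<5x8x32xf32>
  }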
---
 .../Transforms/PackAndUnpackPatterns.cpp      | 14 +++++-
 .../Dialect/Tensor/simplify-pack-unpack.mlir  | 50 ++++++++++++++++---
 2 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 67651a2e38c82d..e20450c95ffd5f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    if (!packOp.getOuterDimsPerm().empty())
+      return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+
     RankedTensorType sourceType = packOp.getSourceType();
     RankedTensorType destType = packOp.getDestType();
-    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
-      return failure();
+    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          packOp, "expects packing at the innermost dimension");
+    }
+
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 049076a67bae53..bdfe18acd86c53 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
 
-// CHECK: func.func @single_dim_packing(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
-// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+// CHECK-LABEL: func.func @single_dim_packing(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<8x32xf32>
 func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
 
 // -----
 
-// CHECK: func.func @single_dim_packing_with_padding(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
-// CHECK-NOT: tensor.expand_shape
-// CHECK: tensor.pack
+// CHECK-LABEL: func.func @single_dim_packing_with_padding(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
 func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
   return %0 : tensor<8x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_packing(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
+func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
+  %empty = tensor.empty() : tensor<5x8x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
+  return %0 : tensor<5x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
+func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
+  %empty = tensor.empty() : tensor<8x5x32xf32>
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
+  return %0 : tensor<8x5x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_packing(
+// CHECK-NOT:     tensor.expand_shape
+// CHECK:         tensor.pack
+func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
+  %empty = tensor.empty() : tensor<8x5x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
+  return %0 : tensor<8x5x32xf32>
+}

>From 91c43c7e2ae488b4c2c8eba7c652c24a058282d2 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Sat, 30 Dec 2023 06:53:25 +0000
Subject: [PATCH 3/3] [mlir][tensor] Add a pattern to simplify tensor.unpack to
 collapse shape
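
The rewrite mirrors the pack case: when only the innermost dimension is unpacked, the shapes are static, there is no outer_dims_perm, and the destination shape is exactly the collapsed source shape (no partial extract), the unpack is a pure reshape and can be expressed as tensor.collapse_shape. A minimal before/after sketch, adapted from the new unpack_1d_to_collapse test below (names are illustrative):

  // Before: unpack the 8x32 tiles back into a rank-1 tensor.
  func.func @example(%src: tensor<8x32xf32>) -> tensor<256xf32> {
    %empty = tensor.empty() : tensor<256xf32>
    %0 = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [32]
        into %empty : tensor<8x32xf32> -> tensor<256xf32>
    return %0 : tensor<256xf32>
  }

  // After: the unpack becomes a pure collapse_shape.
  func.func @example(%src: tensor<8x32xf32>) -> tensor<256xf32> {
    %0 = tensor.collapse_shape %src [[0, 1]]
        : tensor<8x32xf32> into tensor<256xf32>
    return %0 : tensor<256xf32>
  }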

---
 .../Dialect/Tensor/Transforms/Transforms.h    |  1 -
 .../Transforms/PackAndUnpackPatterns.cpp      | 44 +++++++++++-
 .../Dialect/Tensor/simplify-pack-unpack.mlir  | 72 +++++++++++++++++++
 3 files changed, 115 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 35b519e790d1c3..e8a09c4741043b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that simplify `tensor.pack` and
 /// `tensor.unpack` operations.
-/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c95ffd5f..cfd838e85c1b80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   }
 };
 
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+                       Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+                                                    operand, reassociation);
+  }
+
+  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override {
+    if (!unpackOp.getOuterDimsPerm().empty()) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expects no outer_dims_perm");
+    }
+
+    RankedTensorType sourceType = unpackOp.getSourceType();
+    RankedTensorType destType = unpackOp.getDestType();
+    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          unpackOp, "expects unpacking at the innermost dimension");
+    }
+
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value collapsed = insertCollapse(
+        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(unpackOp, collapsed);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
 }
 
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+      patterns.getContext());
 }
 
 } // namespace tensor
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18acd86c53..b78ab9bb3fd87e 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
   return %0 : tensor<8x5x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+  %empty = tensor.empty() : tensor<256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+  return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+  %empty = tensor.empty() : tensor<255xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+  return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+  %size = arith.muli %d0, %c32 : index
+  %empty = tensor.empty(%size) : tensor<?xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+  %empty = tensor.empty() : tensor<256x5xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+  return %0 : tensor<256x5xf32>
+}


