[Mlir-commits] [mlir] 1ff6d9d - [mlir][linalg] Take artificial padding into account for pack/unpack folding. (#150272)

llvmlistbot at llvm.org
Thu Jul 24 13:55:10 PDT 2025


Author: Han-Chung Wang
Date: 2025-07-24T13:55:07-07:00
New Revision: 1ff6d9daec66fb151b9691386c9dad0209648465

URL: https://github.com/llvm/llvm-project/commit/1ff6d9daec66fb151b9691386c9dad0209648465
DIFF: https://github.com/llvm/llvm-project/commit/1ff6d9daec66fb151b9691386c9dad0209648465.diff

LOG: [mlir][linalg] Take artificial padding into account for pack/unpack folding. (#150272)

This revision folds a tensor.pad/extract_slice op into a
linalg.pack/unpack op only when it is safe to do so, i.e., when the
folding would not introduce artificial padding.

The documentation improvement and verifier update will be done in a
separate PR (https://github.com/llvm/llvm-project/pull/149624); this
revision is a step towards it.
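
For reference, both folds boil down to a per-dimension check: with
paddingSize = outerSize * tileSize - unpackedDimSize, folding is allowed
only while paddingSize (plus any constant high padding folded in from a
tensor.pad) stays strictly below tileSize. Below is a minimal standalone
C++ sketch of that arithmetic; the function name, standalone setup, and
sample values are illustrative only and not part of the patch.

  // Standalone sketch of the per-dimension check used by the folds in this
  // patch. Names are illustrative; the actual logic lives in
  // UnPackOp::canFoldSliceOp and FoldPadWithPackOp.
  #include <cstdint>
  #include <cstdio>

  // Returns true when folding keeps the total padding in this packed
  // dimension strictly below one inner tile (i.e., no "artificial" padding).
  static bool foldKeepsPaddingWithinOneTile(int64_t outerSize, int64_t tileSize,
                                            int64_t unpackedDimSize,
                                            int64_t extraHighPad = 0) {
    // Padding already implied by the packed layout for this dimension.
    int64_t paddingSize = outerSize * tileSize - unpackedDimSize;
    // Folding a tensor.pad adds its constant high padding on top of that.
    return paddingSize + extraHighPad < tileSize;
  }

  int main() {
    // Mirrors the new tests: tensor<28x2x1x16x16> unpacked along dim 1,
    // outer size 2, inner tile 16.
    std::printf("%d\n", foldKeepsPaddingWithinOneTile(2, 16, 17)); // 1: foldable
    std::printf("%d\n", foldKeepsPaddingWithinOneTile(2, 16, 16)); // 0: drops a tile
    return 0;
  }

With outer size 2 and tile 16, result sizes 17 through 32 along that
dimension are foldable, which matches the "available dimension size is
[17, 32]" comments in the new tests.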

---------

Signed-off-by: hanhanW <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..62c04bb2ee1ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -144,4 +145,17 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
 
+namespace mlir {
+namespace linalg {
+
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy,
+          typename = std::enable_if_t<std::is_same_v<OpTy, linalg::PackOp> ||
+                                      std::is_same_v<OpTy, linalg::UnPackOp>>>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
+} // namespace linalg
+} // namespace mlir
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALG_H

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..fa572024ff72b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                                    ArrayRef<int64_t> innerPermutation,
                                    ArrayRef<int64_t> outerPermutation);
 
+    /// Returns true if it is statically known that the `sliceOp` result shape
+    /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d5e2ed6bad7b1..4fee81aa2ef67 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4492,6 +4492,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
+
+template <typename OpTy, typename>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
+  SmallVector<int64_t> result(
+      packedType.getShape().take_front(unpackedType.getRank()));
+  if (!packOrUnPack.getOuterDimsPerm().empty()) {
+    applyPermutationToVector(
+        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+  }
+  return result;
+}
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
 // updated mixed-tile-sizes attribute. A tile size is updated only
 // when:
@@ -5452,11 +5475,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
         dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = tensor::ExtractSliceOp::create(
@@ -5499,6 +5518,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   return failure();
 }
 
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Rank-reduced folding is not supported.
+  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+    return false;
+  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+    if (unpackedTypeAfterFold.isDynamicDim(pos))
+      return false;
+    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+      return false;
+    if (ShapedType::isDynamic(tileSize))
+      return false;
+    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                          unpackedTypeAfterFold.getDimSize(pos);
+    if (paddingSize >= tileSize)
+      return false;
+  }
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0415057eda86b..a45a4e314e511 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,33 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
         return failure();
 
+    // Folding is not allowed if it were to introduce artificial padding.
+    // Folding is also disabled in the case of dynamic dimensions and/or tile
+    // sizes - that is because it would be impossible to compute the padding
+    // size and hence to establish whether "artificial" padding would be
+    // created.
+    RankedTensorType unpackedType = packOp.getSourceType();
+    SmallVector<int64_t> outerShapeWithoutTranspose =
+        getPackedOuterShapeWithoutTransposition(packOp);
+    for (auto [pos, tileSize, high] :
+         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+                         padOp.getMixedHighPad())) {
+      if (unpackedType.isDynamicDim(pos))
+        return failure();
+      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+        return failure();
+      if (ShapedType::isDynamic(tileSize))
+        return failure();
+      std::optional<int64_t> cstHigh = getConstantIntValue(high);
+      if (!cstHigh)
+        return failure();
+      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                            unpackedType.getDimSize(pos);
+      // Do not fold the op if it requires artificial padding.
+      if (paddingSize + cstHigh.value() >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +278,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
 
-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return failure();
 
     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..9cbb56e4de884 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1889,31 +1889,84 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 // linalg.unpack + tensor.extract_slice
 //===----------------------------------------------------------------------===//
 
-func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, 10] [1, 1, 1]
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
 
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
-//  CHECK-SAME:     %[[SIZE:.+]]: index
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+  return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
 //       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//  CHECK-SAME:     [0, 0, 0] [28, 17, 15] [1, 1, 1]
 //       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 //  CHECK-SAME:       into %[[DEST_SLICE]]
 //       CHECK:   return %[[UNPACK]]
 
 // -----
 
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+  return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..9da2dea0bbd3c 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,92 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack  %s | FileCheck %s
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control  %s | FileCheck %s --check-prefix=CONTROL
 
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x28x10xf32> {
+  %empty = tensor.empty() : tensor<28x28x15xf32>
+  %unpack = linalg.unpack %arg0
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x28x10xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x17x15xf32> {
+  %empty = tensor.empty() : tensor<28x28x15xf32>
+  %unpack = linalg.unpack %arg0
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+  return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x17x15xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x16x15xf32> {
+  %empty = tensor.empty() : tensor<28x28x15xf32>
+  %unpack = linalg.unpack %arg0
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+  return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
   %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
       : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
   %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-//      CHECK: func @fold_unpack_slice(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-//      CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME:       into %[[INIT]]
-//      CHECK:   return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
 
 // -----
 
@@ -59,48 +129,62 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
 
 // -----
 
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
+func.func @fold_pad_pack(%src: tensor<9x16xf32>) -> tensor<2x1x8x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
-  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  } : tensor<9x16xf32> to tensor<16x16xf32>
+  %empty = tensor.empty() : tensor<2x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
-      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
-  return %pack : tensor<2082x1x8x32xf32>
+      : tensor<16x16xf32> -> tensor<2x1x8x32xf32>
+  return %pack : tensor<2x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @pad_pack
+// CHECK-LABEL: func.func @fold_pad_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2x1x8x32xf32>
 // CHECK:         %[[PACK:.+]] = linalg.pack %[[SRC]]
 // CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
 
 // -----
 
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<9x16xf32>) -> tensor<3x1x8x32xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[8, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  } : tensor<9x16xf32> to tensor<17x16xf32>
+  %empty = tensor.empty() : tensor<3x1x8x32xf32>
+  %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<17x16xf32> -> tensor<3x1x8x32xf32>
+  return %pack : tensor<3x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK:         tensor.pad
+// CHECK:         linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
   %empty = tensor.empty() : tensor<2082x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
   return %pack : tensor<2082x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_with_nofold_attribute(
 // CHECK:         tensor.pad
 // CHECK:         linalg.pack
 
 // -----
 
 func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
-  %c0 = arith.constant 0 : index
   %cst0 = arith.constant 0.000000e+00 : f32
   %cst1 = arith.constant 1.000000e+00 : f32
   %padded = tensor.pad %src low[0, 0] high[15, 0] {


        

