[Mlir-commits] [mlir] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack (PR #189049)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 27 09:27:08 PDT 2026


https://github.com/fabrizio-indirli created https://github.com/llvm/llvm-project/pull/189049

- In linalg.pack, add the optional `extra_pad_tiles` attribute to append a chosen number of additional full tiles of high-padding to each tiled dimension.
- In linalg.unpack, add the dual optional attribute `drop_last_tiles` to drop a chosen number of full outer tiles for each tiled dimension before reconstructing the unpacked result.

>From 557a4cb56be485b202f117ddf91f9a812ee0b2a8 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <fabrizio.indirli at arm.com>
Date: Wed, 11 Mar 2026 17:39:01 +0000
Subject: [PATCH] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack

- In linalg.pack, add the optional `extra_pad_tiles` attribute
  to append a chosen number of additional full tiles of high-padding
  to each tiled dimension.
- In linalg.unpack, add the dual optional attribute `drop_last_tiles` to
  drop a chosen number of full outer tiles for each tiled dimension before
  reconstructing the unpacked result.

Signed-off-by: Fabrizio Indirli <fabrizio.indirli at arm.com>
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  53 ++--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 255 +++++++++++++++---
 .../Transforms/DataLayoutPropagation.cpp      |  38 +--
 .../Transforms/PackAndUnpackPatterns.cpp      |  63 +++--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |   8 +-
 mlir/python/mlir/dialects/linalg/__init__.py  |   4 +
 mlir/test/Dialect/Linalg/canonicalize.mlir    |  26 ++
 mlir/test/Dialect/Linalg/invalid.mlir         |  51 ++++
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  13 +
 .../Dialect/Linalg/transform-lower-pack.mlir  |  49 ++++
 mlir/test/python/dialects/linalg/ops.py       |  10 +-
 11 files changed, 474 insertions(+), 96 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..cdecb41db3123 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -164,13 +164,16 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       tiles divide perfectly the corresponding outer dimension in the result
       tensor. It is UB if the tile does not perfectly divide the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete. Note that it is not allowed to have artificial padding that
-      is not strictly required by linalg.pack (i.e., padding past what is needed
-      to complete the last tile along each packed dimension). It is UB if extra
-      padding is requested.
+      tile complete.
     It is not possible to verify the requirements statically with dynamic
     shapes, so they are treated as UB.
 
+    `extra_pad_tiles` (optional) specifies a number of additional full tiles of
+    high-padding to append for each tiled dimension; any non-zero entry requires
+    `padding_value`. It is indexed like `inner_dims_pos` / `inner_tiles`, must
+    have the same length, and defaults to all zeros when omitted. Entry `i` adds
+    that many full tiles at the end of tiled dimension `inner_dims_pos[i]`.
+
     Example:
     ```mlir
     %0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
@@ -200,7 +203,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
       DenseI64ArrayAttr:$inner_dims_pos, 
       Variadic<Index>:$inner_tiles,
-      DenseI64ArrayAttr:$static_inner_tiles);
+      DenseI64ArrayAttr:$static_inner_tiles,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$extra_pad_tiles);
   let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
@@ -208,7 +212,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       "ArrayRef<int64_t>":$innerDimsPos,
       "ArrayRef<OpFoldResult>":$innerTiles,
       CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
-      CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+      CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+      CArg<"ArrayRef<int64_t>", "{}">:$extraPadTiles)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
@@ -218,28 +223,32 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
         Location loc, ArrayRef<OpFoldResult> sourceDims,
         ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
-        ArrayRef<int64_t> outerDimsPerm = {});
+        ArrayRef<int64_t> outerDimsPerm = {},
+        ArrayRef<int64_t> extraPadTiles = {});
 
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
     static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
-        ArrayRef<int64_t> outerDimsPerm = {});
+        ArrayRef<int64_t> outerDimsPerm = {},
+        ArrayRef<int64_t> extraPadTiles = {});
 
     // Method to get the `MemRefType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
     static MemRefType inferPackedMemRefType(MemRefType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
-        ArrayRef<int64_t> outerDimsPerm = {});
+        ArrayRef<int64_t> outerDimsPerm = {},
+        ArrayRef<int64_t> extraPadTiles = {});
 
     // Returns the shape of the packed type. It is a shared helper that helps
     // type inference methods in a way that ensures that they agree on which 
     // dimensions are dynamic.
     static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
-        ArrayRef<int64_t> outerDimsPerm = {});
+        ArrayRef<int64_t> outerDimsPerm = {},
+        ArrayRef<int64_t> extraPadTiles = {});
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
@@ -249,7 +258,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
                                     ArrayRef<int64_t> innerDimsPos,
                                     ArrayRef<int64_t> outputShape,
                                     ArrayRef<int64_t> outerDimsPerm,
-                                    ArrayRef<OpFoldResult> innerTiles);
+                                    ArrayRef<OpFoldResult> innerTiles,
+                                    ArrayRef<int64_t> extraPadTiles = {});
 
     // Same as above function but here dynamic dimensions are assumed
     // to require padding.
@@ -257,11 +267,13 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
                                           ArrayRef<int64_t> innerDimsPos,
                                           ArrayRef<int64_t> outputShape,
                                           ArrayRef<int64_t> outerDimsPerm,
-                                          ArrayRef<OpFoldResult> innerTiles);
+                                          ArrayRef<OpFoldResult> innerTiles,
+                                          ArrayRef<int64_t> extraPadTiles = {});
 
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
-        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+        ArrayRef<int64_t> extraPadTiles = {});
 
     /// Build and return a new PackOp that is a clone of the current PackOp with
     /// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
@@ -325,6 +337,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     dimensions. If specified, it must have `n - k` elements. If specified, this
     permutation is applied before combining any dimensions.
 
+    `drop_last_tiles` (optional) specifies how many full packed outer tiles to
+    drop for each tiled dimension before reconstructing the unpacked result. It
+    is indexed in the same order as `inner_dims_pos` / `inner_tiles`, must have
+    the same length, and defaults to all zeros when omitted. It can be used to
+    drop the extra tiles appended by a `linalg.pack` with `extra_pad_tiles`.
+
     Note, the unpack operation may drop any padding introduced by the pack
     operation and hence the following holds
     `NumElementsOf(source) >= NumElementsOf(result)`.
@@ -362,20 +380,23 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
       TensorOrMemRef<[AnyType]>:$dest,
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
       DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
-      DenseI64ArrayAttr:$static_inner_tiles);
+      DenseI64ArrayAttr:$static_inner_tiles,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$drop_last_tiles);
   let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest,
     "ArrayRef<int64_t>":$innerDimsPos,
     "ArrayRef<OpFoldResult>":$innerTiles,
-    CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+    CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+    CArg<"ArrayRef<int64_t>", "{}">:$dropLastTiles)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
-        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+        ArrayRef<int64_t> dropLastTiles = {});
 
     /// Build and return a new UnPackOp that is a clone of the current UnPackOp
     /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9698365765e7..8223ab78d1ef5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5104,6 +5104,51 @@ static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
   return staticTiles;
 }
 
+static inline bool isEmptyOrZeroArray(ArrayRef<int64_t> values) {
+  // Treats an empty array the same as an all-zero one.
+  return llvm::none_of(values, [](int64_t value) { return value != 0; });
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(PackOp packOp) {
+  return packOp.getExtraPadTiles(); // Full padding tiles to append.
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(UnPackOp unPackOp) {
+  return unPackOp.getDropLastTiles(); // Full outer tiles to drop.
+}
+
+static void addExtraTrailingTiles(SmallVectorImpl<int64_t> &outerDims,
+                                  ArrayRef<int64_t> innerDimsPos,
+                                  ArrayRef<int64_t> adjustments) {
+  if (adjustments.empty()) // Empty: no adjustments.
+    return;
+  assert(adjustments.size() == innerDimsPos.size() && "size mismatch");
+  for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+    if (!ShapedType::isDynamic(outerDims[dimPos])) // Dynamic stays dynamic.
+      outerDims[dimPos] += adjustments[index];
+  }
+}
+
+static void addExtraTrailingTiles(OpBuilder &builder, Location loc,
+                                  SmallVectorImpl<OpFoldResult> &outerDims,
+                                  ArrayRef<int64_t> innerDimsPos,
+                                  ArrayRef<int64_t> adjustments) {
+  if (adjustments.empty()) // Empty: no adjustments.
+    return;
+  assert(adjustments.size() == innerDimsPos.size() && "size mismatch");
+  AffineExpr d0, c0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), c0);
+  auto addConstant = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + c0);
+  for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+    if (adjustments[index] == 0) // Nothing to add for this dimension.
+      continue;
+    outerDims[dimPos] = affine::makeComposedFoldedAffineApply(
+        builder, loc, addConstant,
+        {outerDims[dimPos], builder.getIndexAttr(adjustments[index])});
+  }
+}
+
 /// Returns true if `dimsPos` is invalid. It is invalid when:
 /// a) It contains duplicate.
 /// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
@@ -5161,12 +5206,26 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
+  ArrayRef<int64_t> extraPadTiles = getExtraTrailingTiles(packOrUnPack);
   if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
     return op->emitError("invalid inner_dims_pos vector");
   if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
     return op->emitError("invalid outer_dims_perm vector");
   if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
     return op->emitError("outer_dims_perm must be a permutation or empty");
+  if (!extraPadTiles.empty() && extraPadTiles.size() != innerDimsPos.size()) {
+    return op->emitError() << (std::is_same<OpTy, PackOp>::value
+                                   ? "extra_pad_tiles"
+                                   : "drop_last_tiles")
+                           << " must have the same number of entries as "
+                              "inner_dims_pos";
+  }
+  if (llvm::any_of(extraPadTiles, [](int64_t value) { return value < 0; })) {
+    return op->emitError() << (std::is_same<OpTy, PackOp>::value
+                                   ? "extra_pad_tiles"
+                                   : "drop_last_tiles")
+                           << " must contain only non-negative values";
+  }
 
   // Tiling factors must be less than or equal to the input rank for pack (or
   // output rank for unpack), and must match the number of `inner_dims_pos`.
@@ -5196,7 +5255,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
       unpackedType.getShape(), packOrUnPack.getStaticTiles(),
-      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm(),
+      extraPadTiles);
   for (auto it : llvm::enumerate(llvm::zip(
            packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
     int64_t dimSize = std::get<0>(it.value());
@@ -5244,6 +5304,7 @@ struct PackOrUnPackTransposeResult {
   SmallVector<int64_t> innerDimsPos;
   SmallVector<OpFoldResult> innerTiles;
   SmallVector<int64_t> outerDimsPerm;
+  SmallVector<int64_t> trailingTileAdjustments;
 };
 } // namespace
 
@@ -5261,6 +5322,8 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
       SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
   metadata.innerTiles =
       SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+  metadata.trailingTileAdjustments =
+      llvm::to_vector(getExtraTrailingTiles(packOrUnPackOp));
   int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                              ? packOrUnPackOp.getSourceRank()
                              : packOrUnPackOp.getDestRank();
@@ -5274,6 +5337,9 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
            "invalid inner permutation");
     applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
     applyPermutationToVector(metadata.innerTiles, innerPermutation);
+    if (!metadata.trailingTileAdjustments.empty())
+      applyPermutationToVector(metadata.trailingTileAdjustments,
+                               innerPermutation);
   }
   if (!outerPermutation.empty()) {
     assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
@@ -5299,7 +5365,7 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
   SmallVector<Type> paddingValueType;
   SmallVector<int64_t> staticTiles;
-  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+  DenseI64ArrayAttr innerDimsPos, outerDimsPerm, extraPadTiles;
   Type sourceType, destType, resultType;
 
   if (parser.parseOperand(source))
@@ -5352,6 +5418,24 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto val : staticTilesAttr.asArrayRef())
     staticTiles.push_back(val);
 
+  if (succeeded(parser.parseOptionalKeyword("extra_pad_tiles"))) {
+    if (parser.parseEqual())
+      return failure();
+    SmallVector<int64_t> extraPadTilesVec;
+    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+          int64_t value;
+          if (parser.parseInteger(value))
+            return failure();
+          extraPadTilesVec.push_back(value);
+          return success();
+        }))
+      return failure();
+    if (!isEmptyOrZeroArray(extraPadTilesVec)) {
+      extraPadTiles =
+          parser.getBuilder().getDenseI64ArrayAttr(extraPadTilesVec);
+    }
+  }
+
   if (parser.parseKeyword("into") || parser.parseOperand(dest))
     return failure();
 
@@ -5395,6 +5479,8 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute("inner_dims_pos", innerDimsPos);
   if (outerDimsPerm)
     result.addAttribute("outer_dims_perm", outerDimsPerm);
+  if (extraPadTiles)
+    result.addAttribute("extra_pad_tiles", extraPadTiles);
 
   SmallVector<int32_t> segmentSizes = {
       1, 1, static_cast<int32_t>(paddingValue.size()),
@@ -5429,11 +5515,18 @@ void PackOp::print(OpAsmPrinter &p) {
   p << " inner_tiles = ";
   printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
 
+  if (!isEmptyOrZeroArray(getExtraPadTiles())) {
+    p << " extra_pad_tiles = [";
+    llvm::interleaveComma(getExtraPadTiles(), p);
+    p << "]";
+  }
+
   p << " into " << getDest();
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {"static_inner_tiles", "inner_dims_pos",
-                           "outer_dims_perm", "operandSegmentSizes"});
+                           "outer_dims_perm", "extra_pad_tiles",
+                           "operandSegmentSizes"});
 
   p << " : " << getSource().getType();
   p << " -> " << getDest().getType();
@@ -5443,7 +5536,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                    Value dest, ArrayRef<int64_t> innerDimsPos,
                    ArrayRef<OpFoldResult> innerTiles,
                    std::optional<Value> paddingValue,
-                   ArrayRef<int64_t> outerDimsPerm) {
+                   ArrayRef<int64_t> outerDimsPerm,
+                   ArrayRef<int64_t> extraPadTiles) {
   assert(innerDimsPos.size() == innerTiles.size() &&
          "number of tile sizes specified must match the specified number of "
          "original dimensions to be tiled");
@@ -5455,7 +5549,10 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
         outerDimsPerm.empty() ? nullptr
                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
-        builder.getDenseI64ArrayAttr(staticTileSizes));
+        builder.getDenseI64ArrayAttr(staticTileSizes),
+        isEmptyOrZeroArray(extraPadTiles)
+            ? nullptr
+            : builder.getDenseI64ArrayAttr(extraPadTiles));
 }
 
 LogicalResult
@@ -5504,7 +5601,10 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
                                  ArrayRef<int64_t> outputShape,
                                  ArrayRef<int64_t> outerDimsPerm,
-                                 ArrayRef<OpFoldResult> innerTiles) {
+                                 ArrayRef<OpFoldResult> innerTiles,
+                                 ArrayRef<int64_t> extraPadTiles) {
+  if (!isEmptyOrZeroArray(extraPadTiles))
+    return true;
   SmallVector<int64_t> outputTileSizes(
       outputShape.take_front(inputShape.size()));
   if (!outerDimsPerm.empty()) {
@@ -5535,7 +5635,10 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outputShape,
                                        ArrayRef<int64_t> outerDimsPerm,
-                                       ArrayRef<OpFoldResult> innerTiles) {
+                                       ArrayRef<OpFoldResult> innerTiles,
+                                       ArrayRef<int64_t> extraPadTiles) {
+  if (!isEmptyOrZeroArray(extraPadTiles))
+    return true;
   SmallVector<int64_t> outputTileSizes(
       outputShape.take_front(inputShape.size()));
   if (!outerDimsPerm.empty()) {
@@ -5576,7 +5679,7 @@ LogicalResult PackOp::verify() {
   if (!paddingValue &&
       requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                           getDestType().getShape(), getOuterDimsPerm(),
-                          getMixedTiles())) {
+                          getMixedTiles(), getExtraPadTiles())) {
     return emitOpError(
         "invalid tile factor or output size provided. Only full tiles are "
         "supported when padding_value is not set");
@@ -5602,7 +5705,8 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
 SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
                                               ArrayRef<int64_t> innerTileSizes,
                                               ArrayRef<int64_t> innerDimsPos,
-                                              ArrayRef<int64_t> outerDimsPerm) {
+                                              ArrayRef<int64_t> outerDimsPerm,
+                                              ArrayRef<int64_t> extraPadTiles) {
   SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
@@ -5614,6 +5718,7 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
     resultShape[tiledDim.value()] = llvm::divideCeilSigned(
         resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
   }
+  addExtraTrailingTiles(resultShape, innerDimsPos, extraPadTiles);
 
   // Swap tile loops if outer_dims_perm is available.
   if (!outerDimsPerm.empty())
@@ -5627,7 +5732,7 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
 SmallVector<OpFoldResult> PackOp::getResultShape(
     OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
     ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
-    ArrayRef<int64_t> outerDimsPerm) {
+    ArrayRef<int64_t> outerDimsPerm, ArrayRef<int64_t> extraPadTiles) {
   SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
 
   AffineExpr s0, s1;
@@ -5638,6 +5743,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
         builder, loc, ceilDivExpr,
         {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
   }
+  addExtraTrailingTiles(builder, loc, resultDims, innerDimsPos, extraPadTiles);
   if (!outerDimsPerm.empty())
     applyPermutationToVector(resultDims, outerDimsPerm);
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
@@ -5645,7 +5751,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   SmallVector<int64_t> resultTypeShape =
       inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
                        asShapeWithAnyValueAsDynamic(innerTileSizes),
-                       innerDimsPos, outerDimsPerm);
+                       innerDimsPos, outerDimsPerm, extraPadTiles);
 
   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -5663,25 +5769,30 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 RankedTensorType PackOp::inferPackedTensorType(
     RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = inferPackedShape(
-      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+    ArrayRef<int64_t> extraPadTiles) {
+  SmallVector<int64_t> resultShape =
+      inferPackedShape(sourceType.getShape(), innerTileSizes, innerDimsPos,
+                       outerDimsPerm, extraPadTiles);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 
 MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
-                                         ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = inferPackedShape(
-      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+                                         ArrayRef<int64_t> outerDimsPerm,
+                                         ArrayRef<int64_t> extraPadTiles) {
+  SmallVector<int64_t> resultShape =
+      inferPackedShape(sourceType.getShape(), innerTileSizes, innerDimsPos,
+                       outerDimsPerm, extraPadTiles);
   return MemRefType::get(resultShape, sourceType.getElementType());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                       ArrayRef<OpFoldResult> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,
-                                      ArrayRef<int64_t> outerDimsPerm) {
+                                      ArrayRef<int64_t> outerDimsPerm,
+                                      ArrayRef<int64_t> extraPadTiles) {
   AffineExpr dim0, dim1;
   bindDims(b.getContext(), dim0, dim1);
   auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
@@ -5703,6 +5814,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
     OpFoldResult tileSize = std::get<1>(it);
     mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
   }
+  addExtraTrailingTiles(b, loc, mixedSizes, innerDimsPos, extraPadTiles);
   if (!outerDimsPerm.empty())
     applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
 
@@ -5716,12 +5828,13 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                      ArrayRef<int64_t> outerPermutation) {
   PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
       *this, innerPermutation, outerPermutation);
-  Value transposedDest =
-      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
-                              metadata.innerDimsPos, metadata.outerDimsPerm);
+  Value transposedDest = createDestinationTensor(
+      b, loc, getSource(), metadata.innerTiles, metadata.innerDimsPos,
+      metadata.outerDimsPerm, metadata.trailingTileAdjustments);
   return PackOp::create(b, loc, getSource(), transposedDest,
                         metadata.innerDimsPos, metadata.innerTiles,
-                        getPaddingValue(), metadata.outerDimsPerm);
+                        getPaddingValue(), metadata.outerDimsPerm,
+                        metadata.trailingTileAdjustments);
 }
 
 template <typename OpTy>
@@ -5810,6 +5923,10 @@ static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
          isIdentityPermutation(unPackOp.getOuterDimsPerm());
 }
 
+static bool hasSameTrailingTileAdjustments(PackOp packOp, UnPackOp unPackOp) {
+  return llvm::equal(unPackOp.getDropLastTiles(), packOp.getExtraPadTiles());
+}
+
 // Return true if pack and unpack have the same tiles.
 // Same SSA values or same integer constants.
 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
@@ -5834,7 +5951,7 @@ static bool paddingIsNotNeeded(PackOp op) {
     return false;
   return !PackOp::requirePaddingValue(
       srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
-      op.getOuterDimsPerm(), op.getMixedTiles());
+      op.getOuterDimsPerm(), op.getMixedTiles(), op.getExtraPadTiles());
 }
 
 /// Returns true if the `srcShape` or `destShape` is different from the one in
@@ -5883,6 +6000,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     if (unPackOp.getSourceType() == packOp.getDestType() &&
         !packOp.getPaddingValue() &&
         hasSameInnerOuterAttribute(packOp, unPackOp) &&
+        hasSameTrailingTileAdjustments(packOp, unPackOp) &&
         haveSameTiles(packOp, unPackOp)) {
       rewriter.replaceOp(packOp, unPackOp.getSource());
       return success();
@@ -5938,6 +6056,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
+  if (!isEmptyOrZeroArray(getExtraTrailingTiles(packOp)))
+    return false;
   // This is a pad if packing only adds ones and we don't transpose dimensions.
 
   // Check that we are not transposing any dimensions.
@@ -6028,10 +6148,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
     // this point. However, in practice, we use them for things that we'd like
     // to preserve. Implement a better abstraction.
-    PackOp newOp =
-        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
-                       op.getInnerDimsPos(), newMixedTileSizes,
-                       op.getPaddingValue(), op.getOuterDimsPerm());
+    PackOp newOp = PackOp::create(rewriter, op.getLoc(), newOperands[0],
+                                  newOperands[1], op.getInnerDimsPos(),
+                                  newMixedTileSizes, op.getPaddingValue(),
+                                  op.getOuterDimsPerm(), op.getExtraPadTiles());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
@@ -6064,7 +6184,7 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand source, dest;
   SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
   SmallVector<int64_t> staticTiles;
-  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+  DenseI64ArrayAttr innerDimsPos, outerDimsPerm, dropLastTiles;
   Type sourceType, destType, resultType;
 
   if (parser.parseOperand(source))
@@ -6109,6 +6229,24 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto val : staticTilesAttr.asArrayRef())
     staticTiles.push_back(val);
 
+  if (succeeded(parser.parseOptionalKeyword("drop_last_tiles"))) {
+    if (parser.parseEqual())
+      return failure();
+    SmallVector<int64_t> dropLastTilesVec;
+    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+          int64_t value;
+          if (parser.parseInteger(value))
+            return failure();
+          dropLastTilesVec.push_back(value);
+          return success();
+        }))
+      return failure();
+    if (!isEmptyOrZeroArray(dropLastTilesVec)) {
+      dropLastTiles =
+          parser.getBuilder().getDenseI64ArrayAttr(dropLastTilesVec);
+    }
+  }
+
   if (parser.parseKeyword("into") || parser.parseOperand(dest))
     return failure();
 
@@ -6147,6 +6285,8 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute("inner_dims_pos", innerDimsPos);
   if (outerDimsPerm)
     result.addAttribute("outer_dims_perm", outerDimsPerm);
+  if (dropLastTiles)
+    result.addAttribute("drop_last_tiles", dropLastTiles);
 
   SmallVector<int32_t> segmentSizes = {
       1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
@@ -6175,11 +6315,18 @@ void UnPackOp::print(OpAsmPrinter &p) {
   p << " inner_tiles = ";
   printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
 
+  if (!isEmptyOrZeroArray(getDropLastTiles())) {
+    p << " drop_last_tiles = [";
+    llvm::interleaveComma(getDropLastTiles(), p);
+    p << "]";
+  }
+
   p << " into " << getDest();
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {"static_inner_tiles", "inner_dims_pos",
-                           "outer_dims_perm", "operandSegmentSizes"});
+                           "outer_dims_perm", "drop_last_tiles",
+                           "operandSegmentSizes"});
 
   p << " : " << getSource().getType();
   p << " -> " << getDest().getType();
@@ -6244,7 +6391,8 @@ Speculation::Speculatability UnPackOp::getSpeculatability() {
 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                      Value dest, ArrayRef<int64_t> innerDimsPos,
                      ArrayRef<OpFoldResult> innerTiles,
-                     ArrayRef<int64_t> outerDimsPerm) {
+                     ArrayRef<int64_t> outerDimsPerm,
+                     ArrayRef<int64_t> dropLastTiles) {
   assert(innerDimsPos.size() == innerTiles.size() &&
          "number of tile sizes specified must match the specified number of "
          "original dimensions to be tiled");
@@ -6255,14 +6403,18 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
         outerDimsPerm.empty() ? nullptr
                               : builder.getDenseI64ArrayAttr(outerDimsPerm),
         builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
-        builder.getDenseI64ArrayAttr(staticTileSizes));
+        builder.getDenseI64ArrayAttr(staticTileSizes),
+        isEmptyOrZeroArray(dropLastTiles)
+            ? nullptr
+            : builder.getDenseI64ArrayAttr(dropLastTiles));
 }
 
 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                         Value source,
                                         ArrayRef<OpFoldResult> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
-                                        ArrayRef<int64_t> outerDimsPerm) {
+                                        ArrayRef<int64_t> outerDimsPerm,
+                                        ArrayRef<int64_t> dropLastTiles) {
   AffineExpr sym0, sym1;
   bindSymbols(b.getContext(), sym0, sym1);
   auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
@@ -6284,6 +6436,21 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
         mixedSizes, invertPermutationVector(outerDimsPerm));
   }
 
+  if (!dropLastTiles.empty()) {
+    AffineExpr d0, c0;
+    bindDims(b.getContext(), d0);
+    bindSymbols(b.getContext(), c0);
+    auto subConstant = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1,
+                                      d0 - c0, b.getContext());
+    for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+      if (dropLastTiles[index] == 0)
+        continue;
+      mixedSizes[dimPos] = affine::makeComposedFoldedAffineApply(
+          b, loc, subConstant,
+          {mixedSizes[dimPos], b.getIndexAttr(dropLastTiles[index])});
+    }
+  }
+
   for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
     mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
 
@@ -6299,7 +6466,8 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
       *this, innerPermutation, outerPermutation);
   return UnPackOp::create(b, loc, transposedSource, getDest(),
                           metadata.innerDimsPos, metadata.innerTiles,
-                          metadata.outerDimsPerm);
+                          metadata.outerDimsPerm,
+                          metadata.trailingTileAdjustments);
 }
 
 /// Returns true if the `srcShape` or `destShape` is different from the one in
@@ -6350,6 +6518,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
       return failure();
     if (packOp.getPaddingValue() ||
         !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !hasSameTrailingTileAdjustments(packOp, unPackOp) ||
         !haveSameTiles(packOp, unPackOp))
       return failure();
     rewriter.replaceOp(unPackOp, packOp.getSource());
@@ -6402,7 +6571,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
     }
     UnPackOp newOp = UnPackOp::create(
         rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
-        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
+        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm(),
+        unPackOp.getDropLastTiles());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
         unPackOp, unPackOp.getResult().getType(), newOp.getResult());
     return success();
@@ -6422,8 +6592,9 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
   SmallVector<int64_t> outerShapeWithoutTranspose =
       getPackedOuterShapeWithoutTransposition(*this);
   SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
-  for (auto [pos, tileSize] :
-       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+  for (auto [index, values] : llvm::enumerate(llvm::zip_equal(
+           this->getInnerDimsPos(), this->getStaticInnerTiles()))) {
+    auto [pos, tileSize] = values;
     areOuterDimsTiled[pos] = true;
     if (unpackedTypeAfterFold.isDynamicDim(pos))
       return false;
@@ -6431,6 +6602,11 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
       return false;
     if (ShapedType::isDynamic(tileSize))
       return false;
+    if (!getDropLastTiles().empty()) {
+      if (outerShapeWithoutTranspose[pos] < getDropLastTiles()[index])
+        return false;
+      outerShapeWithoutTranspose[pos] -= getDropLastTiles()[index];
+    }
     int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                           unpackedTypeAfterFold.getDimSize(pos);
     if (paddingSize >= tileSize)
@@ -6509,9 +6685,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
     // this point. However, in practice, we use them for things that we'd like
     // to preserve. Implement a better abstraction.
-    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
-                                      newOperands[1], op.getInnerDimsPos(),
-                                      newMixedTileSizes, op.getOuterDimsPerm());
+    UnPackOp newOp =
+        UnPackOp::create(rewriter, op.getLoc(), sourceTensor, newOperands[1],
+                         op.getInnerDimsPos(), newMixedTileSizes,
+                         op.getOuterDimsPerm(), op.getDropLastTiles());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index d36ca43a6cbb3..c6dcbe62fc1d6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -365,6 +365,7 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
   auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
     return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
            packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
+           packOp.getExtraPadTiles() == unPackOp.getDropLastTiles() &&
            llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
   };
   DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
@@ -642,10 +643,10 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
     SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
     auto empty = linalg::PackOp::createDestinationTensor(
         rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
-        outerDimsPerm);
+        outerDimsPerm, packOp.getExtraPadTiles());
     auto sourcePack = linalg::PackOp::create(
         rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
-        /*padding=*/std::nullopt, outerDimsPerm);
+        /*padding=*/std::nullopt, outerDimsPerm, packOp.getExtraPadTiles());
 
     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
     SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
@@ -668,10 +669,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
     // replace the other uses.
     if (!padOp->hasOneUse()) {
       auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
-          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
-      UnPackOp unpackedPad =
-          linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
-                                   innerDimsPos, mixedTiles, outerDimsPerm);
+          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm,
+          packOp.getExtraPadTiles());
+      UnPackOp unpackedPad = linalg::UnPackOp::create(
+          rewriter, loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles,
+          outerDimsPerm, packOp.getExtraPadTiles());
       rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack);
     }
 
@@ -803,11 +805,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
 
   auto emptyOp = linalg::PackOp::createDestinationTensor(
       rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
-      projectedInnerDimsPos, newOuterDimsPerm);
+      projectedInnerDimsPos, newOuterDimsPerm, packOp.getExtraPadTiles());
   auto newPackOp = linalg::PackOp::create(
       rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
       projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
-      newOuterDimsPerm);
+      newOuterDimsPerm, packOp.getExtraPadTiles());
 
   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
   // First apply the permutation on the reassociations of the outer dims.
@@ -933,7 +935,8 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
   // dimensions or packing extending dimensions.
   RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
       expandOp.getSrcType(), packOp.getStaticInnerTiles(),
-      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{},
+      packOp.getExtraPadTiles());
   auto reassocExpand =
       getReassociationIndicesForReshape(newPackType, packOp.getDestType());
   if (!reassocExpand)
@@ -942,11 +945,12 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
 
   Value destTensor = linalg::PackOp::createDestinationTensor(
       rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
-      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{},
+      packOp.getExtraPadTiles());
   PackOp packedVal = linalg::PackOp::create(
       rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
       projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
-      /*outerDimsPerm=*/SmallVector<int64_t>{});
+      /*outerDimsPerm=*/SmallVector<int64_t>{}, packOp.getExtraPadTiles());
 
   Value newExpandOp = tensor::ExpandShapeOp::create(
       rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
@@ -1068,17 +1072,19 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
   }
 
   RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
-      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
+      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm,
+      unPackOp.getDropLastTiles());
   auto newExpandOp =
       tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
                                     unPackOp.getSource(), newReassocIndices);
 
   auto emptyOp = linalg::UnPackOp::createDestinationTensor(
       rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
-      projectedInnerDimsPos, newOuterDimsPerm);
+      projectedInnerDimsPos, newOuterDimsPerm, unPackOp.getDropLastTiles());
   auto newUnPackOp = linalg::UnPackOp::create(
       rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
-      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
+      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm,
+      unPackOp.getDropLastTiles());
   rewriter.replaceOp(expandOp, newUnPackOp);
 
   return success();
@@ -1252,7 +1258,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   Value unPackOpRes =
       linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
                                destPack.getSource(), innerDimsPos, mixedTiles,
-                               outerDimsPerm)
+                               outerDimsPerm, destPack.getExtraPadTiles())
           .getResult();
 
   return std::make_tuple(newGenericOp, unPackOpRes);
@@ -1340,7 +1346,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 
     UnPackOp replacement = linalg::UnPackOp::create(
         rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
-        unpackOp.getMixedTiles(), outerDimsPerm);
+        unpackOp.getMixedTiles(), outerDimsPerm, unpackOp.getDropLastTiles());
     rewriter.replaceOp(padOp, replacement);
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..e513443687ffa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -180,6 +180,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
     // TODO: Support Memref UnPackOp. Temporarily return failure.
     if (!unpackOp.hasPureTensorSemantics())
       return failure();
+    if (!unpackOp.getDropLastTiles().empty())
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expects no drop_last_tiles");
 
     ShapedType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
@@ -236,9 +239,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     ShapedType unpackedType = packOp.getSourceType();
     SmallVector<int64_t> outerShapeWithoutTranspose =
         getPackedOuterShapeWithoutTransposition(packOp);
-    for (auto [pos, tileSize, high] :
-         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
-                         padOp.getMixedHighPad())) {
+    for (auto [index, tuple] : llvm::enumerate(llvm::zip_equal(
+             packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+             padOp.getMixedHighPad()))) {
+      auto [pos, tileSize, high] = tuple;
       if (unpackedType.isDynamicDim(pos))
         return failure();
       if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
@@ -248,17 +252,19 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       std::optional<int64_t> cstHigh = getConstantIntValue(high);
       if (!cstHigh)
         return failure();
-      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
-                            unpackedType.getDimSize(pos);
-      // Do not fold the op if it requires artificial padding.
-      if (paddingSize + cstHigh.value() >= tileSize)
+      int64_t originalDim = unpackedType.getDimSize(pos) - cstHigh.value();
+      int64_t baseOuter = llvm::divideCeilSigned(originalDim, tileSize);
+      int64_t adjustedOuter = baseOuter;
+      if (!packOp.getExtraPadTiles().empty())
+        adjustedOuter += packOp.getExtraPadTiles()[index];
+      if (adjustedOuter != outerShapeWithoutTranspose[pos])
         return failure();
     }
 
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
-        packOp.getMixedTiles(), constantPaddingValue,
-        packOp.getOuterDimsPerm());
+        packOp.getMixedTiles(), constantPaddingValue, packOp.getOuterDimsPerm(),
+        packOp.getExtraPadTiles());
     return success();
   }
 
@@ -299,7 +305,8 @@ struct FoldUnpackWithExtractSliceOp
         rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
     rewriter.replaceOpWithNewOp<UnPackOp>(
         sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
-        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
+        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm(),
+        unpackOp.getDropLastTiles());
     return success();
   }
 
@@ -368,6 +375,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
+    SmallVector<int64_t> newExtraPadTilesVec;
     int64_t srcRank = packOp.getSourceRank();
 
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
@@ -382,15 +390,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
       int64_t remappedPosition = transposePerm[i] - srcRank;
       newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+      if (!packOp.getExtraPadTiles().empty())
+        newExtraPadTilesVec.push_back(
+            packOp.getExtraPadTiles()[remappedPosition]);
     }
 
     Value output = packOp.createDestinationTensor(
         rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
-        newInnerDimsPosVec, newOuterDimsPermVec);
+        newInnerDimsPosVec, newOuterDimsPermVec, newExtraPadTilesVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
         linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
-        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
+        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec,
+        newExtraPadTilesVec);
 
     return success();
   }
@@ -432,6 +444,7 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<int64_t> newExtraPadTilesVec;
     SmallVector<int64_t> newOuterDimsPermVec =
         llvm::to_vector(transposePermutation);
 
@@ -440,16 +453,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
-    for (auto dim : innerDimsPos)
+    for (auto [index, dim] : llvm::enumerate(innerDimsPos)) {
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      if (!packOp.getExtraPadTiles().empty())
+        newExtraPadTilesVec.push_back(packOp.getExtraPadTiles()[index]);
+    }
 
     Value output = packOp.createDestinationTensor(
         rewriter, packOp.getLoc(), linalgOp->getOperand(0),
-        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
+        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec,
+        newExtraPadTilesVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
-        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
+        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec,
+        newExtraPadTilesVec);
 
     return success();
   }
@@ -492,13 +510,17 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<int64_t> newDropLastTilesVec;
     SmallVector<int64_t> newOuterDimsPermVec =
         invertPermutationVector(maybePerm.value());
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
-    for (auto dim : innerDimsPos)
+    for (auto [index, dim] : llvm::enumerate(innerDimsPos)) {
       newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
+      if (!unPackOp.getDropLastTiles().empty())
+        newDropLastTilesVec.push_back(unPackOp.getDropLastTiles()[index]);
+    }
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -506,7 +528,8 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
         linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
-        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
+        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec,
+        newDropLastTilesVec);
 
     return success();
   }
@@ -559,6 +582,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
+    SmallVector<int64_t> newDropLastTilesVec;
     if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
@@ -571,6 +595,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
       int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+      if (!unPackOp.getDropLastTiles().empty())
+        newDropLastTilesVec.push_back(
+            unPackOp.getDropLastTiles()[remappedPosition]);
     }
 
     auto elemType =
@@ -580,7 +607,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
         unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
-        newMixedInnerTilesVec, newOuterDimsPermVec);
+        newMixedInnerTilesVec, newOuterDimsPermVec, newDropLastTilesVec);
 
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 260e36fb47f04..74ecfb5ccc3e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -549,8 +549,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
       if (areConstantTiles && operandType.hasStaticShape() &&
           !linalg::PackOp::requirePaddingValue(
               operandType.getShape(), innerPos,
-              cast<ShapedType>(dest.getType()).getShape(), {},
-              innerPackSizes)) {
+              cast<ShapedType>(dest.getType()).getShape(), {}, innerPackSizes,
+              {})) {
         packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                  innerPos, innerPackSizes));
       } else {
@@ -588,7 +588,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
     // Build the symmetrical UnPackOp to the existing PackOp.
     unPackOps.push_back(linalg::UnPackOp::create(
         rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
-        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+        maybePackedInit.getOuterDimsPerm(),
+        maybePackedInit.getExtraPadTiles()));
     results.push_back(unPackOps.back().getResult());
   }
 
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1f6d1a68fbbb8..f6414c934e2f5 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -299,6 +299,7 @@ def pack(
     *,
     padding_value=None,
     outer_dims_perm=None,
+    extra_pad_tiles=None,
     loc=None,
     ip=None,
 ) -> ir.Value:
@@ -321,6 +322,7 @@ def pack(
             static_inner_tiles=static_inner_tiles,
             padding_value=padding_value,
             outer_dims_perm=outer_dims_perm,
+            extra_pad_tiles=extra_pad_tiles,
             loc=loc,
             ip=ip,
         )
@@ -334,6 +336,7 @@ def unpack(
     inner_tiles,
     *,
     outer_dims_perm=None,
+    drop_last_tiles=None,
     loc=None,
     ip=None,
 ) -> ir.Value:
@@ -354,6 +357,7 @@ def unpack(
             inner_tiles=dynamic_inner_tiles,
             static_inner_tiles=static_inner_tiles,
             outer_dims_perm=outer_dims_perm,
+            drop_last_tiles=drop_last_tiles,
             loc=loc,
             ip=ip,
         )
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 0c5a1c6108ae3..bba1261d02289 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2158,6 +2158,32 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: func.func @keep_unpack_pack_tensor_with_tile_adjustments
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @keep_unpack_pack_tensor_with_tile_adjustments(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<17x10x8x32xf32> {
+  %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2]
+             into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %cst = arith.constant 0.0 : f32
+  %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2]
+             into %x : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  return %packed : tensor<17x10x8x32xf32>
+}
+
+// CHECK-LABEL: func.func @do_not_fold_unpack_pack_tensor_with_mismatched_tile_adjustments
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @do_not_fold_unpack_pack_tensor_with_mismatched_tile_adjustments(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>, %packed_dest: tensor<17x9x8x32xf32>) -> tensor<17x9x8x32xf32> {
+  %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2]
+             into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %cst = arith.constant 0.0 : f32
+  %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 1]
+             into %packed_dest : tensor<127x255xf32> -> tensor<17x9x8x32xf32>
+  return %packed : tensor<17x9x8x32xf32>
+}
+
+// -----
+
 // Test that pack/unpack canonicalization is disabled for memref versions.
 // CHECK-LABEL: func.func @negative_pack_unpack_memref_no_canonicalization
 // CHECK: linalg.pack
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d9192cbda14e7..75347fe28ef87 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1911,6 +1911,32 @@ func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: ten
 
 // -----
 
+func.func @pack_invalid_extra_pad_tiles_length(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+  %pad = arith.constant 0.0 : f32
+  // expected-error at +1 {{extra_pad_tiles must have the same number of entries as inner_dims_pos}}
+  %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+  return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
+func.func @pack_invalid_negative_extra_pad_tiles(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+  %pad = arith.constant 0.0 : f32
+  // expected-error at +1 {{extra_pad_tiles must contain only non-negative values}}
+  %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, -1] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+  return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
+func.func @pack_invalid_extra_pad_tiles_without_padding(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+  // expected-error at +1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 0] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+  return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.unpack
 //===----------------------------------------------------------------------===//
@@ -1939,6 +1965,22 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
 
 // -----
 
+func.func @unpack_invalid_drop_last_tiles_length(%source: tensor<17x8x8x32xf32>, %dest: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  // expected-error at +1 {{drop_last_tiles must have the same number of entries as inner_dims_pos}}
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1] into %dest : tensor<17x8x8x32xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_negative_drop_last_tiles(%source: tensor<17x8x8x32xf32>, %dest: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  // expected-error at +1 {{drop_last_tiles must contain only non-negative values}}
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [-1, 0] into %dest : tensor<17x8x8x32xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
 func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
   %cst = arith.constant 0.0 : f32
   // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
@@ -1977,6 +2019,15 @@ func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>,
 
 // -----
 
+func.func @unpack_invalid_drop_last_tiles_shape(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+  // expected-error at +1 {{expected 'tensor<4x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] drop_last_tiles = [2] into %output
+      : tensor<3x8xf32> -> tensor<9xf32>
+  return %0 : tensor<9xf32>
+}
+
+// -----
+
 func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
   // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
   %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index bfb92c3289a49..ace2579566025 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -829,3 +829,16 @@ func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x25
   // CHECK: return %[[RESULT]] : tensor<128x256xf32>
   return %0 : tensor<128x256xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @test_pack_unpack_tensor_with_tile_adjustments
+func.func @test_pack_unpack_tensor_with_tile_adjustments(%src: tensor<127x255xf32>, %packed_init: tensor<17x10x8x32xf32>) -> tensor<127x255xf32> {
+  %zero = arith.constant 0.0 : f32
+  // CHECK: %[[PACK:.*]] = linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %{{.*}} : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  %packed = linalg.pack %src padding_value(%zero : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %packed_init : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  // CHECK: %[[UNPACK:.*]] = linalg.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %{{.*}} : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %src : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  // CHECK: return %[[UNPACK]] : tensor<127x255xf32>
+  return %unpacked : tensor<127x255xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b6fe67a9ae1f3..0fbbd812a9292 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -229,6 +229,55 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @pack_with_extra_pad_tiles(
+func.func @pack_with_extra_pad_tiles(%input: tensor<127x255xf32>, %dest: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+  %pad_val = arith.constant 0.0 : f32
+  // CHECK: tensor.pad {{.*}} low[0, 0] high[9, 65]
+  // CHECK:   : tensor<127x255xf32> to tensor<136x320xf32>
+  %packed = linalg.pack %input padding_value(%pad_val : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %dest
+    : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  return %packed : tensor<17x10x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %pack_op = transform.structured.match ops{["linalg.pack"]} in %root
+      : (!transform.any_op) -> !transform.op<"linalg.pack">
+    transform.structured.lower_pack %pack_op : (!transform.op<"linalg.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_with_drop_last_tiles(
+func.func @unpack_with_drop_last_tiles(%packed: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<127x255xf32> {
+  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x10x32xf32>
+  // CHECK: %[[TRAN:.*]] = linalg.transpose
+  // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{.*}} : tensor<17x8x10x32xf32> into tensor<136x320xf32>
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [127, 255] [1, 1]
+  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %dest
+    : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  return %unpacked : tensor<127x255xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %root
+      : (!transform.any_op) -> !transform.op<"linalg.unpack">
+    transform.structured.lower_unpack %unpack_op : (!transform.op<"linalg.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">,
+          !transform.op<"linalg.copy">)
+    transform.yield
+  }
+}
+
+// -----
+
 // When an unpack is a plain 'unpad', lower it to a simple extract_slice.
 // CHECK-LABEL: func.func @unpack_as_pad(
 func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 68b32098b7782..8f3fad247a6d7 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -644,7 +644,7 @@ def testPackUnPackOp():
 
             @func.FuncOp.from_py_func(
                 RankedTensorType.get((128, 128), f32),
-                RankedTensorType.get((16, 16, 8, 8), f32),
+                RankedTensorType.get((18, 17, 8, 8), f32),
             )
             def tensor_pack(src, dst):
                 packed = linalg.pack(
@@ -652,6 +652,7 @@ def tensor_pack(src, dst):
                     dst,
                     inner_dims_pos=[1, 0],
                     inner_tiles=[8, 8],
+                    extra_pad_tiles=[1, 2],
                     padding_value=arith.constant(f32, 0.0),
                 )
 
@@ -660,6 +661,7 @@ def tensor_pack(src, dst):
                     src,
                     inner_dims_pos=[0, 1],
                     inner_tiles=[8, 8],
+                    drop_last_tiles=[2, 1],
                 )
 
                 return unpacked
@@ -679,10 +681,10 @@ def memref_pack(src, dst):
                 )
 
         # CHECK-LABEL:   func.func @tensor_pack(
-        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<18x17x8x8xf32>) -> tensor<128x128xf32> {
         # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
-        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
-        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] extra_pad_tiles = [1, 2] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<18x17x8x8xf32>
+        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] drop_last_tiles = [2, 1] into %[[VAL_0]] : tensor<18x17x8x8xf32> -> tensor<128x128xf32>
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         # CHECK-LABEL:   func.func @memref_pack(



More information about the Mlir-commits mailing list