[Mlir-commits] [mlir] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack (PR #189049)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 02:42:08 PDT 2026


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/189049

>From c0aee44bf510503a28f6cce8242733fc4050c15d Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <fabrizio.indirli at arm.com>
Date: Wed, 11 Mar 2026 17:39:01 +0000
Subject: [PATCH] [mlir][linalg] Allow extra pad tiles in linalg.pack & unpack

- In linalg.pack, allow the result shape to contain additional
  full tiles of high-padding in the tiled dimensions.
  This is allowed only when the affected shape is known statically.
  Note that in the tile-and-fuse path, the extra full tiles are
  only fusible when the tiled pack covers the whole affected
  source dimension in a single slice.
- Similarly, in linalg.unpack allow the source outer dimensions
  to be larger than the resulting unpacked shape: any additional
  outer extent is treated as extra full-tile high-padding and is
  discarded from the high end when reconstructing the result tensor.

Signed-off-by: Fabrizio Indirli <fabrizio.indirli at arm.com>
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  16 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  87 +++++++-----
 .../Transforms/PackAndUnpackPatterns.cpp      |  18 +--
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  47 +++++--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |   3 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir    |  26 ++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  44 +++---
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  13 ++
 .../Dialect/Linalg/transform-lower-pack.mlir  |  53 +++++++-
 .../Linalg/vectorization/linalg-ops.mlir      |  53 ++++++++
 .../lower-to-loops-using-interface.mlir       |  82 ++++++++++++
 .../tile-and-fuse-consumer-using-slices.mlir  | 126 +++++++++++++++++-
 .../tile-and-fuse-consumer.mlir               | 126 +++++++++++++++++-
 mlir/test/python/dialects/linalg/ops.py       |   8 +-
 14 files changed, 607 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..b7e58ae38e0b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -164,10 +164,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       tiles divide perfectly the corresponding outer dimension in the result
       tensor. It is UB if the tile does not perfectly divide the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete. Note that it is not allowed to have artificial padding that
-      is not strictly required by linalg.pack (i.e., padding past what is needed
-      to complete the last tile along each packed dimension). It is UB if extra
-      padding is requested.
+      tile complete.
     It is not possible to verify the requirements statically with dynamic
     shapes, so they are treated as UB.
 
@@ -185,7 +182,11 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Note: Only tiled dimensions can be padded.
     ```
 
-    Invalid example that has artificial padding:
+    The packed outer dimensions may be larger than the minimum shape implied by
+    `shape(source)` and `inner_tiles`, if `padding_value` is specified.
+    Any additional outer extent is treated as extra full-tile high-padding.
+
+    Invalid example that has artificial padding without a `padding_value`:
     ```mlir
-    %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+    %0 = linalg.pack %src inner_dims_pos = [0]
         inner_tiles = [8] into %dest
@@ -329,6 +330,11 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     operation and hence the following holds
     `NumElementsOf(source) >= NumElementsOf(result)`.
 
+    The packed source outer dimensions may be larger than the minimum shape
+    implied by `shape(result)` and `inner_tiles`. Any additional outer extent
+    is treated as extra full-tile high-padding and is discarded from the high
+    end when reconstructing the result tensor.
+
     Examples:
 
     ```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9698365765e7..0e3989aa60992 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5104,6 +5104,21 @@ static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
   return staticTiles;
 }
 
+/// Returns the outer dimension sizes of `packedShape` (its leading
+/// `unpackedRank` entries), reordered from packed order back to unpacked
+/// dimension order by undoing `outerDimsPerm` when it is non-empty.
+static SmallVector<int64_t>
+getPackedOuterShapeInUnpackedOrder(ArrayRef<int64_t> packedShape,
+                                   size_t unpackedRank,
+                                   ArrayRef<int64_t> outerDimsPerm) {
+  assert((outerDimsPerm.empty() || outerDimsPerm.size() == unpackedRank) &&
+         "expected outer_dims_perm to be empty or match the unpacked rank");
+  SmallVector<int64_t> result(packedShape.take_front(unpackedRank));
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(result, invertPermutationVector(outerDimsPerm));
+  return result;
+}
+
 /// Returns true if `dimsPos` is invalid. It is invalid when:
 /// a) It contains duplicate.
 /// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
@@ -5215,20 +5230,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
               << " but dynamic tile size";
      }
    }
-  if (failed(
-          verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
-    auto elementType = unpackedType.getElementType();
-    Type expectedType, actualType;
-    if (packOrUnPack.hasPureTensorSemantics()) {
-      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
-      actualType = RankedTensorType::get(packedType.getShape(), elementType);
-    } else {
-      expectedType = MemRefType::get(expectedPackedShape, elementType);
-      actualType = MemRefType::get(packedType.getShape(), elementType);
+  SmallVector<int64_t> expectedOuterShape = getPackedOuterShapeInUnpackedOrder(
+      expectedPackedShape, unpackedRank, outerDimPerm);
+  SmallVector<int64_t> actualOuterShape =
+      getPackedOuterShapeWithoutTransposition(packOrUnPack);
+  llvm::SmallBitVector areOuterDimsTiled(unpackedRank);
+  for (int64_t pos : innerDimsPos)
+    areOuterDimsTiled.set(pos);
+  for (auto [index, actualOuter] : llvm::enumerate(actualOuterShape)) {
+    int64_t expectedOuter = expectedOuterShape[index];
+    if (ShapedType::isDynamic(actualOuter) ||
+        ShapedType::isDynamic(expectedOuter))
+      continue;
+    if (areOuterDimsTiled[index]) {
+      if (actualOuter < expectedOuter) {
+        return op->emitError("expected packed outer dimension ")
+               << index << " to be at least " << expectedOuter << ", got "
+               << actualOuter;
+      }
+      continue;
+    }
+    if (actualOuter != expectedOuter) {
+      return op->emitError("expected packed outer dimension ")
+             << index << " to be " << expectedOuter << ", got " << actualOuter;
     }
-    return op->emitError("expected ")
-           << expectedType << " for the packed domain value, got "
-           << actualType;
   }
   return success();
 }
@@ -5505,28 +5530,33 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                   ArrayRef<int64_t> outputShape,
                                   ArrayRef<int64_t> outerDimsPerm,
                                   ArrayRef<OpFoldResult> innerTiles) {
-  SmallVector<int64_t> outputTileSizes(
-      outputShape.take_front(inputShape.size()));
+  SmallVector<int64_t> outputTileSizes = getPackedOuterShapeInUnpackedOrder(
+      outputShape, inputShape.size(), outerDimsPerm);
   if (!outerDimsPerm.empty()) {
     assert(outerDimsPerm.size() == outputTileSizes.size() &&
            "expected output and outer_dims_perm to have same size");
-    applyPermutationToVector(outputTileSizes,
-                             invertPermutationVector(outerDimsPerm));
   }
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
     if (ShapedType::isDynamic(inputShape[pos]))
       continue;
     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
+    int64_t outerSize = outputTileSizes[pos];
     if (!constantTile) {
-      if (ShapedType::isStatic(outputTileSizes[pos]) &&
-          (inputShape[pos] % outputTileSizes[pos] != 0))
+      // Without padding, outerSize * tile == inputShape[pos] must hold, so a
+      // static outer size that does not divide the input requires padding.
+      if (ShapedType::isStatic(outerSize) && outerSize > 0 &&
+          inputShape[pos] % outerSize != 0)
         return true;
-    } else {
-      assert(*constantTile != 0 && "static tile size can't be zero");
-      if (inputShape[pos] % (*constantTile) != 0) {
-        return true;
-      }
-    }
+      continue;
+    }
+    assert(*constantTile != 0 && "static tile size can't be zero");
+    // A partial last tile always requires padding.
+    if (inputShape[pos] % (*constantTile) != 0)
+      return true;
+    // Extra full tiles (beyond those covering the input) are padding too.
+    if (ShapedType::isStatic(outerSize) &&
+        outerSize * (*constantTile) != inputShape[pos])
+      return true;
   }
   return false;
 }
@@ -5536,13 +5566,11 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                        ArrayRef<int64_t> outputShape,
                                        ArrayRef<int64_t> outerDimsPerm,
                                        ArrayRef<OpFoldResult> innerTiles) {
-  SmallVector<int64_t> outputTileSizes(
-      outputShape.take_front(inputShape.size()));
+  SmallVector<int64_t> outputTileSizes = getPackedOuterShapeInUnpackedOrder(
+      outputShape, inputShape.size(), outerDimsPerm);
   if (!outerDimsPerm.empty()) {
     assert(outerDimsPerm.size() == outputTileSizes.size() &&
            "expected output and outer_dims_perm to have same size");
-    applyPermutationToVector(outputTileSizes,
-                             invertPermutationVector(outerDimsPerm));
   }
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
     if (ShapedType::isDynamic(inputShape[pos]) ||
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..c90981b69cd27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -248,10 +248,11 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       std::optional<int64_t> cstHigh = getConstantIntValue(high);
       if (!cstHigh)
         return failure();
-      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
-                            unpackedType.getDimSize(pos);
-      // Do not fold the op if it requires artificial padding.
-      if (paddingSize + cstHigh.value() >= tileSize)
+      // Conservatively fold only when the packed outer size is exactly the
+      // minimum needed to hold the unpadded source dimension.
+      int64_t originalDim = unpackedType.getDimSize(pos) - cstHigh.value();
+      int64_t baseOuter = llvm::divideCeilSigned(originalDim, tileSize);
+      if (baseOuter != outerShapeWithoutTranspose[pos])
         return failure();
     }
 
@@ -440,7 +441,7 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
-    for (auto dim : innerDimsPos)
+    for (int64_t dim : innerDimsPos)
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
@@ -497,7 +498,7 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
-    for (auto dim : innerDimsPos)
+    for (int64_t dim : innerDimsPos)
       newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
 
     if (!outerDimsPerm.empty())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..7981dd211ee36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1101,13 +1101,16 @@ struct PackOpTiling
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         packOp.getDimAndTileMapping();
-    SmallVector<int64_t> outerShapeWithoutTranspose(
-        packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
+    SmallVector<OpFoldResult> outerShapeWithoutTranspose =
+        tensor::getMixedSizes(b, loc, packOp.getDest());
+    outerShapeWithoutTranspose.resize(packOp.getSourceRank());
     if (!packOp.getOuterDimsPerm().empty()) {
       applyPermutationToVector(
           outerShapeWithoutTranspose,
           invertPermutationVector(packOp.getOuterDimsPerm()));
     }
+    SmallVector<OpFoldResult> srcDimValues =
+        tensor::getMixedSizes(b, loc, packOp.getSource());
     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
       if (dimAndTileMapping.count(dim)) {
         FailureOr<int64_t> cstTileSize =
@@ -1125,18 +1128,12 @@ struct PackOpTiling
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
-        int64_t destDimSize = outerShapeWithoutTranspose[dim];
         bool isTiled = failed(cstTileSize) ||
                        ShapedType::isDynamic(srcDimSize) ||
                        cstTileSize.value() < srcDimSize;
         if (!isTiled) {
           outerDimOffsets.push_back(offsets[dim]);
-          if (ShapedType::isStatic(destDimSize)) {
-            outerDimSizes.push_back(b.getIndexAttr(destDimSize));
-          } else {
-            outerDimSizes.push_back(
-                b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
-          }
+          outerDimSizes.push_back(outerShapeWithoutTranspose[dim]);
           continue;
         }

@@ -1165,9 +1162,37 @@ struct PackOpTiling
         bindSymbols(b.getContext(), sym);
         auto avOffset = AV(dim0).bind(offsets[dim]);
         auto avSize = AV(dim0).bind(sizes[dim]);
+        auto avSrcDim = AV(dim0).bind(srcDimValues[dim]);
         auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
-        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
-        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+        OpFoldResult outerDimOffset = ab.floor(avOffset, avTileSize);
+        // Minimum packed outer-dimension size needed for the current tiled
+        // source slice.
+        OpFoldResult minOuterDimSize = ab.ceil(avSize, avTileSize);
+        outerDimOffsets.push_back(outerDimOffset);
+        // Minimum packed outer-dimension size needed to represent the entire
+        // source dimension, ignoring any intentional extra pad tiles.
+        OpFoldResult minGlobalOuterDimSize = ab.ceil(avSrcDim, avTileSize);
+        std::optional<int64_t> actualOuterCst =
+            getConstantIntValue(outerShapeWithoutTranspose[dim]);
+        std::optional<int64_t> minGlobalOuterCst =
+            getConstantIntValue(minGlobalOuterDimSize);
+        bool hasStaticExtraFullTiles = actualOuterCst && minGlobalOuterCst &&
+                                       *actualOuterCst > *minGlobalOuterCst;
+
+        // Without statically known extra full tiles (e.g. a dynamic packed
+        // outer or source dim), use the minimum size needed by this slice.
+        if (!hasStaticExtraFullTiles) {
+          outerDimSizes.push_back(minOuterDimSize);
+          continue;
+        }
+
+        // Extra full tiles are only fusible when the tiled pack covers the
+        // whole affected source dimension in a single slice.
+        if (!isZeroInteger(offsets[dim]) ||
+            !isEqualConstantIntOrValue(sizes[dim], srcDimValues[dim]))
+          return failure();
+
+        outerDimSizes.push_back(outerShapeWithoutTranspose[dim]);
       } else {
         outerDimOffsets.push_back(offsets[dim]);
         outerDimSizes.push_back(sizes[dim]);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 260e36fb47f04..ea401435b6449 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -588,7 +588,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
     // Build the symmetrical UnPackOp to the existing PackOp.
     unPackOps.push_back(linalg::UnPackOp::create(
         rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
-        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+        maybePackedInit.getOuterDimsPerm()));
     results.push_back(unPackOps.back().getResult());
   }

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 0c5a1c6108ae3..efe06d0e72963 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2158,6 +2158,32 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {

 // -----

+// CHECK-LABEL: func.func @keep_unpack_pack_tensor_with_larger_packed_shape
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @keep_unpack_pack_tensor_with_larger_packed_shape(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<17x10x8x32xf32> {
+  %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+             into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %cst = arith.constant 0.0 : f32
+  %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+             into %x : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  return %packed : tensor<17x10x8x32xf32>
+}
+
+// CHECK-LABEL: func.func @do_not_fold_unpack_pack_tensor_with_mismatched_packed_shape
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @do_not_fold_unpack_pack_tensor_with_mismatched_packed_shape(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>, %packed_dest: tensor<17x9x8x32xf32>) -> tensor<17x9x8x32xf32> {
+  %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+             into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %cst = arith.constant 0.0 : f32
+  %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+             into %packed_dest : tensor<127x255xf32> -> tensor<17x9x8x32xf32>
+  return %packed : tensor<17x9x8x32xf32>
+}
+
+// -----
+
 // Test that pack/unpack canonicalization is disabled for memref versions.
 // CHECK-LABEL: func.func @negative_pack_unpack_memref_no_canonicalization
 // CHECK: linalg.pack
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d9192cbda14e7..c81c3c6b92abf 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1811,22 +1811,6 @@ func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %o

 // -----

-func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
-  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32>  -> tensor<10x8x?x?xf32>
-  return %0 : tensor<10x8x?x?xf32>
-}
-
-// -----
-
-func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
-  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
-  %0 = linalg.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32>  -> tensor<8x10x?x?xf32>
-  return %0 : tensor<8x10x?x?xf32>
-}
-
-// -----
-
 func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
   // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
   %0 = linalg.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
@@ -1911,6 +1895,14 @@ func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: ten

 // -----

+func.func @pack_invalid_larger_result_without_padding(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+  return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.unpack
 //===----------------------------------------------------------------------===//
@@ -1939,13 +1931,13 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t

 // -----

-func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+func.func @pack_invalid_small_result_shape_1d(%input: tensor<9xf32>, %output: tensor<1x8xf32>) -> tensor<1x8xf32> {
   %cst = arith.constant 0.0 : f32
-  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  // expected-error@+1 {{expected packed outer dimension 0 to be at least 2, got 1}}
   %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
       inner_tiles = [8] into %output
-      : tensor<9xf32> -> tensor<3x8xf32>
-  return %0 : tensor<3x8xf32>
+      : tensor<9xf32> -> tensor<1x8xf32>
+  return %0 : tensor<1x8xf32>
 }

 // -----
@@ -1953,7 +1945,7 @@ func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
+  // expected-error@+1 {{expected packed outer dimension 0 to be at least 16, got 4}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }
@@ -1961,24 +1953,24 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
 // -----

 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
-  // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+  // expected-error@+1 {{expected packed outer dimension 1 to be at least 8, got 7}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
   return %0 : tensor<8x7x16x32xf32>
 }

 // -----

-func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
-  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+func.func @unpack_invalid_small_source_shape_1d(%input: tensor<1x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+  // expected-error@+1 {{expected packed outer dimension 0 to be at least 2, got 1}}
   %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
-      : tensor<3x8xf32> -> tensor<9xf32>
+      : tensor<1x8xf32> -> tensor<9xf32>
   return %0 : tensor<9xf32>
 }

 // -----

 func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
-  // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+  // expected-error@+1 {{expected packed outer dimension 1 to be at least 32, got 8}}
   %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index bfb92c3289a49..b776da9a0b86f 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -829,3 +829,16 @@ func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x25
   // CHECK: return %[[RESULT]] : tensor<128x256xf32>
   return %0 : tensor<128x256xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @test_pack_unpack_tensor_with_extra_tiles
+func.func @test_pack_unpack_tensor_with_extra_tiles(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<127x255xf32> {
+  %pad = arith.constant 0.0 : f32
+  // CHECK: %[[PACK:.*]] = linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  %0 = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  // CHECK: %[[UNPACK:.*]] = linalg.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  %1 = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg0 : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  // CHECK: return %[[UNPACK]] : tensor<127x255xf32>
+  return %1 : tensor<127x255xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b6fe67a9ae1f3..9f777c5aac425 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -229,6 +229,55 @@ module attributes {transform.with_named_sequence} {

 // -----

+// CHECK-LABEL: func.func @pack_with_larger_result_shape(
+func.func @pack_with_larger_result_shape(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // CHECK: tensor.pad {{.*}} low[0, 0] high[9, 65]
+  // CHECK:   : tensor<127x255xf32> to tensor<136x320xf32>
+  %pack = linalg.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1
+    : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  return %pack : tensor<17x10x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"linalg.pack">
+    transform.structured.lower_pack %pack : (!transform.op<"linalg.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_with_larger_packed_shape(
+func.func @unpack_with_larger_packed_shape(%arg0: tensor<17x10x8x32xf32>, %arg1: tensor<127x255xf32>) -> tensor<127x255xf32> {
+  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x10x32xf32>
+  // CHECK: %[[TRAN:.*]] = linalg.transpose
+  // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{.*}} : tensor<17x8x10x32xf32> into tensor<136x320xf32>
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [127, 255] [1, 1]
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1
+    : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  return %unpack : tensor<127x255xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"linalg.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"linalg.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">,
+          !transform.op<"linalg.copy">)
+    transform.yield
+  }
+}
+
+// -----
+
 // When an unpack is a plain 'unpad', lower it to a simple extract_slice.
 // CHECK-LABEL: func.func @unpack_as_pad(
 func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
@@ -625,8 +674,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:   permutation = [0, 2, 1, 3]
 //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 // CHECK-SAME:   : tensor<?x?x?x?xf32> into tensor<?x?xf32>
-//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//  CHECK-DAG: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
 //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
 // CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
 //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index a5d94bc4f581c..fc0d0674b299f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1252,6 +1252,31 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
 
 // -----
 
+// CHECK-LABEL: test_vectorize_unpack_larger_packed_shape_no_vector_sizes
+// CHECK-SAME:      %[[PACKED:.*]]: tensor<17x10x8x32xf32>
+// CHECK-SAME:      %[[INIT:.*]]: tensor<127x255xf32>
+func.func @test_vectorize_unpack_larger_packed_shape_no_vector_sizes(%packed: tensor<17x10x8x32xf32>, %init: tensor<127x255xf32>) -> tensor<127x255xf32> {
+  // CHECK-DAG: %[[POISON:.*]] = ub.poison : f32
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[VREAD:.*]] = vector.transfer_read %[[PACKED]]{{.*}} %[[POISON]] {{.*}} : tensor<17x10x8x32xf32>, vector<17x10x8x32xf32>
+  // CHECK: %[[VTRANS:.*]] = vector.transpose %[[VREAD]], [0, 2, 1, 3] : vector<17x10x8x32xf32> to vector<17x8x10x32xf32>
+  // CHECK: %[[VFLAT:.*]] = vector.shape_cast %[[VTRANS]] : vector<17x8x10x32xf32> to vector<136x320xf32>
+  // CHECK: %[[RES:.*]] = vector.transfer_write %[[VFLAT]], %[[INIT]]{{.*}} : vector<136x320xf32>, tensor<127x255xf32>
+  // CHECK: return %[[RES]] : tensor<127x255xf32>
+  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %init
+    : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+  return %unpacked : tensor<127x255xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %unpack_op = transform.structured.match ops{["linalg.unpack"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %unpack_op : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x4x16x16xf32>
 // CHECK-SAME:      %[[DEST:.*]]: tensor<64x127xf32>
@@ -1421,6 +1446,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @pack_with_larger_result_shape_no_vector_sizes
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<127x255xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<17x10x8x32xf32>
+func.func @pack_with_larger_result_shape_no_vector_sizes(%input: tensor<127x255xf32>, %output: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+  %zero = arith.constant 0.000000e+00 : f32
+  %packed = linalg.pack %input padding_value(%zero : f32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32]
+    into %output : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+  return %packed : tensor<17x10x8x32xf32>
+}
+//  CHECK-DAG: %[[ZERO_F:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[IDX0:.*]] = arith.constant 0 : index
+//      CHECK: %[[VREAD:.*]] = vector.transfer_read %[[INPUT]][%[[IDX0]], %[[IDX0]]], %[[ZERO_F]] : tensor<127x255xf32>, vector<136x320xf32>
+//      CHECK: %[[VSHAPED:.*]] = vector.shape_cast %[[VREAD]] : vector<136x320xf32> to vector<17x8x10x32xf32>
+//      CHECK: %[[VTRANS:.*]] = vector.transpose %[[VSHAPED]], [0, 2, 1, 3] : vector<17x8x10x32xf32> to vector<17x10x8x32xf32>
+//      CHECK: %[[RES:.*]] = vector.transfer_write %[[VTRANS]], %[[OUTPUT]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true, true, true, true]} : vector<17x10x8x32xf32>, tensor<17x10x8x32xf32>
+//      CHECK: return %[[RES]] : tensor<17x10x8x32xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %pack_op = transform.structured.match ops{["linalg.pack"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %pack_op : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @pack_with_dynamic_dims
 // CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:      %[[DEST:.*]]: tensor<?x?x16x2xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index ec9b606491910..9c792a3ea55f9 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -488,6 +488,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @NC_to_NCnc_extra(%arg0: memref<128x128xf32>, %arg1: memref<18x17x8x8xf32>, %arg2: f32) {
+  linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8]
+      into %arg1
+    : memref<128x128xf32> -> memref<18x17x8x8xf32>
+  return
+}
+// The pack above needs 16x16 outer 8x8 tiles but writes 18x17; every element of the extra tiles is %arg2.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %pack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[$MAP_EXTRA:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc_extra(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<128x128xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<18x17x8x8xf32>
+// CHECK-SAME:    %[[PAD:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C17:.*]] = arith.constant 17 : index
+// CHECK-DAG:     %[[C18:.*]] = arith.constant 18 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C18]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C17]] step %[[C1]] {
+// CHECK:             scf.for %[[K:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[L:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK-DAG:             %[[SRC_I:.*]] = affine.apply #[[$MAP_EXTRA]](%[[I]], %[[L]])
+// CHECK-DAG:             %[[SRC_J:.*]] = affine.apply #[[$MAP_EXTRA]](%[[J]], %[[K]])
+// CHECK:                 %[[BOUND_I:.*]] = arith.cmpi slt, %[[SRC_I]], %[[C128]] : index
+// CHECK:                 %[[BOUND_J:.*]] = arith.cmpi slt, %[[SRC_J]], %[[C128]] : index
+// CHECK:                 %[[IN_BOUNDS:.*]] = arith.andi %[[BOUND_I]], %[[BOUND_J]] : i1
+// CHECK:                 %[[VAL:.*]] = scf.if %[[IN_BOUNDS]] -> (f32) {
+// CHECK:                   %[[LOAD:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<128x128xf32>
+// CHECK:                   scf.yield %[[LOAD]]
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[PAD]]
+// CHECK:                 }
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<18x17x8x8xf32>
+
+// -----
+
 func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
   linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1
     : memref<128x256xf32> -> memref<4x8x32x32xf32>
@@ -564,6 +610,42 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @NCnc_to_NC_drop(%arg0: memref<128x128xf32>, %arg1: memref<18x17x8x8xf32>) {
+  linalg.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %arg0
+    : memref<18x17x8x8xf32> -> memref<128x128xf32>
+  return
+}
+// The unpack above only reads the first 16x16 of the 18x17 outer tiles of %arg1; the extra tiles are dropped.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_DROP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[$MAP_DROP_MOD:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @NCnc_to_NC_drop(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<128x128xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<18x17x8x8xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK-DAG:         %[[FLOOR_I:.*]] = affine.apply #[[$MAP_DROP_FLOOR]](%[[I]])
+// CHECK-DAG:         %[[MOD_I:.*]] = affine.apply #[[$MAP_DROP_MOD]](%[[I]])
+// CHECK-DAG:         %[[FLOOR_J:.*]] = affine.apply #[[$MAP_DROP_FLOOR]](%[[J]])
+// CHECK-DAG:         %[[MOD_J:.*]] = affine.apply #[[$MAP_DROP_MOD]](%[[J]])
+// CHECK:             %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_I]], %[[MOD_J]]] : memref<18x17x8x8xf32>
+// CHECK:             memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x128xf32>
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
 func.func @KCck_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
   linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg0
     : memref<4x8x32x32xf32> -> memref<128x256xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
index 62dd7faec4eb7..344e4c5a76415 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -502,6 +502,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(%arg0: tensor<4x4xf32>) -> tensor<2x5x16x1xf32> {
+  %0 = tensor.empty() : tensor<2x5x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+      inner_tiles = [16, 1] into %0
+    : tensor<4x4xf32> -> tensor<2x5x16x1xf32>
+  return %pack : tensor<2x5x16x1xf32>
+}
+// Single-iteration loop (0 to 4 step 16): the fused pack covers the whole 4x4 source, so its extra pad tiles (2x5 vs. 1x4) are fusible.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 4, 16)>
+//      CHECK: func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<2x5x16x1xf32>
+//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+//  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+
+// -----
+
 func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
   %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -546,6 +598,50 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<64x4x16xf32>) -> tensor<64x4x16xf32> {
+  %0 = scf.forall (%arg3) = (0) to (64) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+    %dest = tensor.extract_slice %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+    %1 = linalg.exp ins(%src : tensor<16x32xf32>) outs(%dest : tensor<16x32xf32>) -> tensor<16x32xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x32xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %arg2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+  return %pack : tensor<64x4x16xf32>
+}
+// The loop slices dim 0, which is not packed; the extra pad tiles (4 outer tiles vs. 2 needed) are on unsliced dim 1.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (64) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%{{.*}} : f32)
+// CHECK-SAME:        inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+
+// -----
+
 // It is valid to fuse the pack op in perfect tiling scenario when the dimension
 // is dynamic and padding is not needed.
 
@@ -649,7 +745,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+// CHECK-SAME:        into %{{.*}}
 //      CHECK:      scf.forall.in_parallel {
 //      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
 // CHECK-SAME:            [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
@@ -694,6 +790,33 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @nofuse_pack_with_larger_packed_shape_on_sliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<64x4x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %1 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %2 = tensor.empty() : tensor<64x4x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+  return %pack : tensor<64x4x16xf32>
+}
+// The loop splits dim 1 into two slices, but the pack has extra pad tiles on dim 1 (4 vs. 2), so fusion must fail.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index
@@ -730,6 +853,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
 //      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
 // CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 0137e2a69a46e..1d3b34547b440 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -545,6 +545,58 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(%arg0: tensor<4x4xf32>) -> tensor<2x5x16x1xf32> {
+  %0 = tensor.empty() : tensor<2x5x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+      inner_tiles = [16, 1] into %0
+    : tensor<4x4xf32> -> tensor<2x5x16x1xf32>
+  return %pack : tensor<2x5x16x1xf32>
+}
+// Single-iteration loop (0 to 4 step 16): the fused pack covers the whole 4x4 source, so its extra pad tiles (2x5 vs. 1x4) are fusible.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<2x5x16x1xf32>
+//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]])
+//  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+
+// -----
+
 func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
   %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -588,6 +640,50 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<64x4x16xf32>) -> tensor<64x4x16xf32> {
+  %0 = scf.forall (%arg3) = (0) to (64) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+    %dest = tensor.extract_slice %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+    %1 = linalg.exp ins(%src : tensor<16x32xf32>) outs(%dest : tensor<16x32xf32>) -> tensor<16x32xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x32xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %arg2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+  return %pack : tensor<64x4x16xf32>
+}
+// The loop slices dim 0, which is not packed; the extra pad tiles (4 outer tiles vs. 2 needed) are on unsliced dim 1.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (64) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%{{.*}} : f32)
+// CHECK-SAME:        inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+
+// -----
+
 // It is valid to fuse the pack op in perfect tiling scenario when the dimension
 // is dynamic and padding is not needed.
 
@@ -686,7 +782,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+// CHECK-SAME:        into %{{.*}}
 //      CHECK:      scf.forall.in_parallel {
 //      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
 // CHECK-SAME:            [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
@@ -732,6 +828,33 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @nofuse_pack_with_larger_packed_shape_on_sliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<64x4x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %2 = tensor.empty() : tensor<64x4x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  // expected-error @below {{failed to fuse consumer of slice}}
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+  return %pack : tensor<64x4x16xf32>
+}
+// The loop splits dim 1 into two slices, but the pack has extra pad tiles on dim 1 (4 vs. 2), so fusion must fail.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index
@@ -772,6 +895,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
 //      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
 // CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 68b32098b7782..a903409ea06e4 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -644,7 +644,7 @@ def testPackUnPackOp():
 
             @func.FuncOp.from_py_func(
                 RankedTensorType.get((128, 128), f32),
-                RankedTensorType.get((16, 16, 8, 8), f32),
+                RankedTensorType.get((18, 17, 8, 8), f32),
             )
             def tensor_pack(src, dst):
                 packed = linalg.pack(
@@ -679,10 +679,10 @@ def memref_pack(src, dst):
                 )
 
         # CHECK-LABEL:   func.func @tensor_pack(
-        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+        # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<18x17x8x8xf32>) -> tensor<128x128xf32> {
         # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
-        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
-        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+        # CHECK:           %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<18x17x8x8xf32>
+        # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<18x17x8x8xf32> -> tensor<128x128xf32>
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         # CHECK-LABEL:   func.func @memref_pack(



More information about the Mlir-commits mailing list