[Mlir-commits] [mlir] [mlir][linalg] Fix crash when folding tensor.cast into unpack using static packed shape for inner tiles (PR #188000)

Hocky Yudhiono llvmlistbot at llvm.org
Tue Mar 31 20:10:30 PDT 2026


https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/188000

>From ddf0ddfbcdecec8c66fcbaf9ef3bd2e22b6f948d Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 17:44:17 +0800
Subject: [PATCH 1/6] [mlir][linalg] Bail out tensor.cast pack/unpack fold on
 unprovable tile sizes

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  41 +++--
 ...canonicalize-dynamic-pack-unpack-tile.mlir | 149 ++++++++++++++++++
 2 files changed, 176 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ad2909f656eea..5b75be21e4822 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5000,8 +5000,10 @@ template SmallVector<int64_t>
 //  * a dim from newPackedTy is static, and
 //  * the corresponding size from mixedTiles is still dynamic.
 // Otherwise, the original tile size is preserved.
+// Returns failure when a dynamic tile cannot be proven to match the static
+// packed dim.
 // Note - packed-type-dim and mixed-tile-size should always match!
-static SmallVector<OpFoldResult>
+static FailureOr<SmallVector<OpFoldResult>>
 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                      SmallVector<OpFoldResult> mixedTiles) {
   SmallVector<OpFoldResult> newMixedTileSizes;
@@ -5015,17 +5017,21 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
       continue;
     }
 
-    // If the current result dim is static, update the dynamic mixed-size
-    // (provided the original value is dynamic).
+    // If the current result dim is static, update the dynamic mixed-size only
+    // when the original dynamic value is a known constant matching `shape`.
+    // Otherwise, bail out and let the fold fail conservatively.
     OpFoldResult tile = std::get<1>(it);
     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
       // Already a constant
       newMixedTileSizes.push_back(tile);
     } else {
-      assert(getConstantIntValue(tile).value() == shape &&
-             "tile size and dim size don't match!");
-      newMixedTileSizes.push_back(
-          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      std::optional<int64_t> constTile = getConstantIntValue(tile);
+      if (constTile.has_value() && constTile.value() == shape) {
+        newMixedTileSizes.push_back(
+            rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
+      } else {
+        return failure();
+      }
     }
   }
 
@@ -5995,8 +6001,11 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes =
+    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
         getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
+    if (failed(newMixedTileSizes))
+      return rewriter.notifyMatchFailure(
+          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -6004,7 +6013,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     // to preserve. Implement a better abstraction.
     PackOp newOp =
         PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
-                       op.getInnerDimsPos(), newMixedTileSizes,
+                       op.getInnerDimsPos(), newMixedTileSizes.value(),
                        op.getPaddingValue(), op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
@@ -6476,16 +6485,20 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
-        rewriter, sourceTensor.getType(), op.getMixedTiles());
+    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, sourceTensor.getType(),
+                             op.getMixedTiles());
+    if (failed(newMixedTileSizes))
+      return rewriter.notifyMatchFailure(
+          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
     // this point. However, in practice, we use them for things that we'd like
     // to preserve. Implement a better abstraction.
-    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
-                                      newOperands[1], op.getInnerDimsPos(),
-                                      newMixedTileSizes, op.getOuterDimsPerm());
+    UnPackOp newOp = UnPackOp::create(
+        rewriter, op.getLoc(), sourceTensor, newOperands[1],
+        op.getInnerDimsPos(), newMixedTileSizes.value(), op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
diff --git a/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir b/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
new file mode 100644
index 0000000000000..eec3e3acc93fb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s --inline -canonicalize="test-convergence" -split-input-file | FileCheck %s --check-prefixes=CHECK
+
+// CHECK: func.func @dynamic_tile_arg_no_fold
+// CHECK-SAME:  %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[TILE]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @dynamic_tile_arg_no_fold(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+    %0 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @dynamic_tile_from_inlined_mismatch_no_fold
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @get_tile() -> index {
+    %c256 = arith.constant 256 : index
+    return %c256 : index
+  }
+  func.func @dynamic_tile_from_inlined_mismatch_no_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+    %0 = call @get_tile() : () -> index
+    %1 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @constant_tile_from_inlined_match_folds
+// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-NOT:   tensor.cast
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @get_tile() -> index {
+    %c8 = arith.constant 8 : index
+    return %c8 : index
+  }
+  func.func @constant_tile_from_inlined_match_folds(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+    %0 = call @get_tile() : () -> index
+    %1 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_arg
+// CHECK-SAME:  %[[SRC:.+]]: tensor<8x3xi32>, %[[TILE:.+]]: index, %[[DEST:.+]]: tensor<?x3x?x1xi32>
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%[[TILE]], 1]
+// CHECK:         into %[[DEST]] : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+  func.func @pack_dynamic_tile_arg(%arg0: tensor<8x3xi32>, %arg1: index,
+      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+    %c0 = arith.constant 0 : i32
+    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+    %pack = linalg.pack %cast
+      padding_value(%c0 : i32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%arg1, 1]
+      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+    return %pack : tensor<?x3x?x1xi32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_mismatch
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%[[C256]], 1]
+// CHECK:         into %{{.+}} : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+  func.func @pack_get_tile() -> index {
+    %c256 = arith.constant 256 : index
+    return %c256 : index
+  }
+  func.func @pack_dynamic_tile_from_inlined_mismatch(%arg0: tensor<8x3xi32>,
+      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+    %c0 = arith.constant 0 : i32
+    %0 = call @pack_get_tile() : () -> index
+    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+    %pack = linalg.pack %cast
+      padding_value(%c0 : i32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%0, 1]
+      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+    return %pack : tensor<?x3x?x1xi32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_match_fold
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%{{.+}}, 1]
+// CHECK:         into %{{.+}} : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+  func.func @pack_get_tile() -> index {
+    %c8 = arith.constant 8 : index
+    return %c8 : index
+  }
+  func.func @pack_dynamic_tile_from_inlined_match_fold(%arg0: tensor<8x3xi32>,
+      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+    %c0 = arith.constant 0 : i32
+    %0 = call @pack_get_tile() : () -> index
+    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+    %pack = linalg.pack %cast
+      padding_value(%c0 : i32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%0, 1]
+      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+    return %pack : tensor<?x3x?x1xi32>
+  }
+}
\ No newline at end of file

>From 88092ddeba16130b20d2f3e87e602153e204af1f Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 18:11:14 +0800
Subject: [PATCH 2/6] [mlir][linalg] Use auto for getNewMixedTileSizes result in
 FoldTensorCastPackOp

Co-authored-by: Renato Golin <rengolin at systemcall.eu>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5b75be21e4822..effc0e3232b4f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -6001,7 +6001,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
+    auto newMixedTileSizes =
         getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
     if (failed(newMixedTileSizes))
       return rewriter.notifyMatchFailure(

>From c27d25d41ca28cb34bd41126bb950c4056eeb947 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 18:11:27 +0800
Subject: [PATCH 3/6] [mlir][linalg] Use auto for getNewMixedTileSizes result in
 FoldTensorCastUnPackOp

Co-authored-by: Renato Golin <rengolin at systemcall.eu>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index effc0e3232b4f..918ba1284e043 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -6485,7 +6485,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
+    auto newMixedTileSizes =
         getNewMixedTileSizes(rewriter, sourceTensor.getType(),
                              op.getMixedTiles());
     if (failed(newMixedTileSizes))

>From d5c9b284eb84541edb5f631ab0d4f335a1ae8787 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 18:14:29 +0800
Subject: [PATCH 4/6] [mlir][linalg] Address review feedback: simplify
 getNewMixedTileSizes

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 918ba1284e043..9d2afdbb8e6bf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5021,17 +5021,11 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
     // when the original dynamic value is a known constant matching `shape`.
     // Otherwise, bail out and let the fold fail conservatively.
     OpFoldResult tile = std::get<1>(it);
-    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
-      // Already a constant
-      newMixedTileSizes.push_back(tile);
+    std::optional<int64_t> constTile = getConstantIntValue(tile);
+    if (constTile.has_value() && constTile.value() == shape) {
+      newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
     } else {
-      std::optional<int64_t> constTile = getConstantIntValue(tile);
-      if (constTile.has_value() && constTile.value() == shape) {
-        newMixedTileSizes.push_back(
-            rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
-      } else {
-        return failure();
-      }
+      return failure();
     }
   }
 
@@ -6485,9 +6479,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    auto newMixedTileSizes =
-        getNewMixedTileSizes(rewriter, sourceTensor.getType(),
-                             op.getMixedTiles());
+    auto newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
     if (failed(newMixedTileSizes))
       return rewriter.notifyMatchFailure(
           op, "unable to prove dynamic tile sizes after folding tensor.cast");

>From cc2065c38e3431f96eb8c83d6c62fb8b3ee025d7 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 26 Mar 2026 09:54:13 +0800
Subject: [PATCH 5/6] [mlir][linalg] Refactor test cases

---
 ...canonicalize-dynamic-pack-unpack-tile.mlir | 149 ------------------
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 118 ++++++++++++++
 2 files changed, 118 insertions(+), 149 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir

diff --git a/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir b/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
deleted file mode 100644
index eec3e3acc93fb..0000000000000
--- a/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
+++ /dev/null
@@ -1,149 +0,0 @@
-// RUN: mlir-opt %s --inline -canonicalize="test-convergence" -split-input-file | FileCheck %s --check-prefixes=CHECK
-
-// CHECK: func.func @dynamic_tile_arg_no_fold
-// CHECK-SAME:  %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
-// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
-// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-// CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
-// CHECK-SAME:    inner_dims_pos = [0, 1]
-// CHECK-SAME:    inner_tiles = [%[[TILE]], 1]
-// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-module {
-  func.func @dynamic_tile_arg_no_fold(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
-    %0 = tensor.empty() : tensor<7x3xi32>
-    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-    return %unpack : tensor<7x3xi32>
-  }
-}
-
-
-// -----
-
-// CHECK-LABEL: func.func @dynamic_tile_from_inlined_mismatch_no_fold
-// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
-// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
-// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
-// CHECK-SAME:    inner_dims_pos = [0, 1]
-// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
-// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-module {
-  func.func @get_tile() -> index {
-    %c256 = arith.constant 256 : index
-    return %c256 : index
-  }
-  func.func @dynamic_tile_from_inlined_mismatch_no_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
-    %0 = call @get_tile() : () -> index
-    %1 = tensor.empty() : tensor<7x3xi32>
-    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-    return %unpack : tensor<7x3xi32>
-  }
-}
-
-
-// -----
-
-// CHECK-LABEL: func.func @constant_tile_from_inlined_match_folds
-// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
-// CHECK-NOT:   tensor.cast
-// CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
-// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
-// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-module {
-  func.func @get_tile() -> index {
-    %c8 = arith.constant 8 : index
-    return %c8 : index
-  }
-  func.func @constant_tile_from_inlined_match_folds(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
-    %0 = call @get_tile() : () -> index
-    %1 = tensor.empty() : tensor<7x3xi32>
-    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-    return %unpack : tensor<7x3xi32>
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func.func @pack_dynamic_tile_arg
-// CHECK-SAME:  %[[SRC:.+]]: tensor<8x3xi32>, %[[TILE:.+]]: index, %[[DEST:.+]]: tensor<?x3x?x1xi32>
-// CHECK:       %[[PACK:.+]] = linalg.pack
-// CHECK:         padding_value
-// CHECK:         inner_dims_pos = [0, 1]
-// CHECK:         inner_tiles = [%[[TILE]], 1]
-// CHECK:         into %[[DEST]] : tensor
-// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
-module {
-  func.func @pack_dynamic_tile_arg(%arg0: tensor<8x3xi32>, %arg1: index,
-      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
-    %c0 = arith.constant 0 : i32
-    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
-    %pack = linalg.pack %cast
-      padding_value(%c0 : i32)
-      inner_dims_pos = [0, 1]
-      inner_tiles = [%arg1, 1]
-      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
-    return %pack : tensor<?x3x?x1xi32>
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_mismatch
-// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
-// CHECK:       %[[PACK:.+]] = linalg.pack
-// CHECK:         padding_value
-// CHECK:         inner_dims_pos = [0, 1]
-// CHECK:         inner_tiles = [%[[C256]], 1]
-// CHECK:         into %{{.+}} : tensor
-// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
-module {
-  func.func @pack_get_tile() -> index {
-    %c256 = arith.constant 256 : index
-    return %c256 : index
-  }
-  func.func @pack_dynamic_tile_from_inlined_mismatch(%arg0: tensor<8x3xi32>,
-      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
-    %c0 = arith.constant 0 : i32
-    %0 = call @pack_get_tile() : () -> index
-    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
-    %pack = linalg.pack %cast
-      padding_value(%c0 : i32)
-      inner_dims_pos = [0, 1]
-      inner_tiles = [%0, 1]
-      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
-    return %pack : tensor<?x3x?x1xi32>
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_match_fold
-// CHECK:       %[[PACK:.+]] = linalg.pack
-// CHECK:         padding_value
-// CHECK:         inner_dims_pos = [0, 1]
-// CHECK:         inner_tiles = [%{{.+}}, 1]
-// CHECK:         into %{{.+}} : tensor
-// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
-module {
-  func.func @pack_get_tile() -> index {
-    %c8 = arith.constant 8 : index
-    return %c8 : index
-  }
-  func.func @pack_dynamic_tile_from_inlined_match_fold(%arg0: tensor<8x3xi32>,
-      %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
-    %c0 = arith.constant 0 : i32
-    %0 = call @pack_get_tile() : () -> index
-    %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
-    %pack = linalg.pack %cast
-      padding_value(%c0 : i32)
-      inner_dims_pos = [0, 1]
-      inner_tiles = [%0, 1]
-      into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
-    return %pack : tensor<?x3x?x1xi32>
-  }
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 77c1c3da17166..f0949fb6b1839 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2160,3 +2160,121 @@ func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>
   linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32>
   return
 }
+
+// -----
+// CHECK-LABEL: func.func @no_fold_unpack_cast_inner_tile_dynamic_arg
+// CHECK-SAME:  %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[TILE]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+func.func @no_fold_unpack_cast_inner_tile_dynamic_arg(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+  %0 = tensor.empty() : tensor<7x3xi32>
+  %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+  return %unpack : tensor<7x3xi32>
+}
+
+
+// -----
+// CHECK-LABEL: func.func @no_fold_unpack_cast_inner_tile_inlined_mismatch
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+func.func @no_fold_unpack_cast_inner_tile_inlined_mismatch(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+  %c256 = arith.constant 256 : index
+  %1 = tensor.empty() : tensor<7x3xi32>
+  %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c256, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+  return %unpack : tensor<7x3xi32>
+}
+// -----
+
+// CHECK-LABEL: func.func @unpack_cast_inner_tile_inlined_match_fold
+// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-NOT:   tensor.cast
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+func.func @unpack_cast_inner_tile_inlined_match_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+  %c8 = arith.constant 8 : index
+  %1 = tensor.empty() : tensor<7x3xi32>
+  %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+  return %unpack : tensor<7x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_fold_pack_cast_inner_tile_dynamic_arg
+// CHECK-SAME:  %[[SRC:.+]]: tensor<8x3xi32>, %[[TILE:.+]]: index, %[[DEST:.+]]: tensor<?x3x?x1xi32>
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%[[TILE]], 1]
+// CHECK:         into %[[DEST]] : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+func.func @no_fold_pack_cast_inner_tile_dynamic_arg(%arg0: tensor<8x3xi32>, %arg1: index,
+    %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+  %c0 = arith.constant 0 : i32
+  %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+  %pack = linalg.pack %cast
+    padding_value(%c0 : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [%arg1, 1]
+    into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+  return %pack : tensor<?x3x?x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_fold_pack_cast_inner_tile_inlined_mismatch
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%[[C256]], 1]
+// CHECK:         into %{{.+}} : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+func.func @no_fold_pack_cast_inner_tile_inlined_mismatch(%arg0: tensor<8x3xi32>,
+    %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+  %c0 = arith.constant 0 : i32
+  %c256 = arith.constant 256 : index
+  %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+  %pack = linalg.pack %cast
+    padding_value(%c0 : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [%c256, 1]
+    into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+  return %pack : tensor<?x3x?x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_cast_inner_tile_inlined_match_fold
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK:         padding_value
+// CHECK:         inner_dims_pos = [0, 1]
+// CHECK:         inner_tiles = [%{{.+}}, 1]
+// CHECK:         into %{{.+}} : tensor
+// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
+func.func @pack_cast_inner_tile_inlined_match_fold(%arg0: tensor<8x3xi32>,
+    %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+  %c0 = arith.constant 0 : i32
+  %c8 = arith.constant 8 : index
+  %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+  %pack = linalg.pack %cast
+    padding_value(%c0 : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [%c8, 1]
+    into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+  return %pack : tensor<?x3x?x1xi32>
+}

>From 4fdd31f8b140efbd789989460583812894fa044e Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Thu, 26 Mar 2026 10:19:56 +0800
Subject: [PATCH 6/6] [mlir][linalg] Use static packed dims as inner tiles when
 folding tensor.cast; update tests

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 41 ++++----------
 mlir/test/Dialect/Linalg/canonicalize.mlir | 63 +++++-----------------
 2 files changed, 23 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d2afdbb8e6bf..cbf1fa5a32502 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4995,17 +4995,14 @@ template SmallVector<int64_t>
     getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
 
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
-// updated mixed-tile-sizes attribute. A tile size is updated only
-// when:
-//  * a dim from newPackedTy is static, and
-//  * the corresponding size from mixedTiles is still dynamic.
-// Otherwise, the original tile size is preserved.
-// Returns failure when a dynamic tile cannot be proven to match the static
-// packed dim.
+// updated mixed-tile-sizes list. For each inner packed dimension that is static
+// in `newPackedTy`, the tile is set to that static size (replacing SSA values
+// or mismatched constants). Dynamic packed dimensions preserve the original
+// tile. The folded tensor type is treated as authoritative for static extents.
 // Note - packed-type-dim and mixed-tile-size should always match!
-static FailureOr<SmallVector<OpFoldResult>>
+static SmallVector<OpFoldResult>
 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
-                     SmallVector<OpFoldResult> mixedTiles) {
+                     ArrayRef<OpFoldResult> mixedTiles) {
   SmallVector<OpFoldResult> newMixedTileSizes;
   for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                                .getShape()
@@ -5016,17 +5013,7 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
       newMixedTileSizes.push_back(std::get<1>(it));
       continue;
     }
-
-    // If the current result dim is static, update the dynamic mixed-size only
-    // when the original dynamic value is a known constant matching `shape`.
-    // Otherwise, bail out and let the fold fail conservatively.
-    OpFoldResult tile = std::get<1>(it);
-    std::optional<int64_t> constTile = getConstantIntValue(tile);
-    if (constTile.has_value() && constTile.value() == shape) {
-      newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
-    } else {
-      return failure();
-    }
+    newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
   }
 
   return newMixedTileSizes;
@@ -5995,11 +5982,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    auto newMixedTileSizes =
+    SmallVector<OpFoldResult> newMixedTileSizes =
         getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
-    if (failed(newMixedTileSizes))
-      return rewriter.notifyMatchFailure(
-          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -6007,7 +5991,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     // to preserve. Implement a better abstraction.
     PackOp newOp =
         PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
-                       op.getInnerDimsPos(), newMixedTileSizes.value(),
+                       op.getInnerDimsPos(), newMixedTileSizes,
                        op.getPaddingValue(), op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
@@ -6479,11 +6463,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    auto newMixedTileSizes = getNewMixedTileSizes(
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
         rewriter, sourceTensor.getType(), op.getMixedTiles());
-    if (failed(newMixedTileSizes))
-      return rewriter.notifyMatchFailure(
-          op, "unable to prove dynamic tile sizes after folding tensor.cast");
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -6491,7 +6472,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     // to preserve. Implement a better abstraction.
     UnPackOp newOp = UnPackOp::create(
         rewriter, op.getLoc(), sourceTensor, newOperands[1],
-        op.getInnerDimsPos(), newMixedTileSizes.value(), op.getOuterDimsPerm());
+        op.getInnerDimsPos(), newMixedTileSizes, op.getOuterDimsPerm());
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f0949fb6b1839..285bdc21fbd1a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2162,16 +2162,15 @@ func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>
 }
 
 // -----
-// CHECK-LABEL: func.func @no_fold_unpack_cast_inner_tile_dynamic_arg
+// CHECK-LABEL: func.func @fold_unpack_cast_inner_tile_dynamic_arg
 // CHECK-SAME:  %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
-// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
-// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-// CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 // CHECK-SAME:    inner_dims_pos = [0, 1]
-// CHECK-SAME:    inner_tiles = [%[[TILE]], 1]
-// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK-SAME:    inner_tiles = [8, 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
 // CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-func.func @no_fold_unpack_cast_inner_tile_dynamic_arg(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+func.func @fold_unpack_cast_inner_tile_dynamic_arg(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
   %0 = tensor.empty() : tensor<7x3xi32>
   %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
   %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
@@ -2180,35 +2179,19 @@ func.func @no_fold_unpack_cast_inner_tile_dynamic_arg(%arg0: tensor<1x3x8x1xi32>
 
 
 // -----
-// CHECK-LABEL: func.func @no_fold_unpack_cast_inner_tile_inlined_mismatch
-// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
-// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
-// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
-// CHECK-SAME:    inner_dims_pos = [0, 1]
-// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
-// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-func.func @no_fold_unpack_cast_inner_tile_inlined_mismatch(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
-  %c256 = arith.constant 256 : index
-  %1 = tensor.empty() : tensor<7x3xi32>
-  %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c256, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
-  return %unpack : tensor<7x3xi32>
-}
-// -----
-
-// CHECK-LABEL: func.func @unpack_cast_inner_tile_inlined_match_fold
+// Mismatched constant tile vs static packed shape: fold still drops the cast and
+// takes inner tile sizes from the refined packed type.
+// CHECK-LABEL: func.func @fold_unpack_cast_inner_tile_inlined_mismatch
 // CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
 // CHECK-NOT:   tensor.cast
 // CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
 // CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
 // CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
-func.func @unpack_cast_inner_tile_inlined_match_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
-  %c8 = arith.constant 8 : index
+func.func @fold_unpack_cast_inner_tile_inlined_mismatch(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+  %c256 = arith.constant 256 : index
   %1 = tensor.empty() : tensor<7x3xi32>
   %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
-  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+  %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c256, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
   return %unpack : tensor<7x3xi32>
 }
 
@@ -2256,25 +2239,3 @@ func.func @no_fold_pack_cast_inner_tile_inlined_mismatch(%arg0: tensor<8x3xi32>,
     into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
   return %pack : tensor<?x3x?x1xi32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @pack_cast_inner_tile_inlined_match_fold
-// CHECK:       %[[PACK:.+]] = linalg.pack
-// CHECK:         padding_value
-// CHECK:         inner_dims_pos = [0, 1]
-// CHECK:         inner_tiles = [%{{.+}}, 1]
-// CHECK:         into %{{.+}} : tensor
-// CHECK:       return %[[PACK]] : tensor<?x3x?x1xi32>
-func.func @pack_cast_inner_tile_inlined_match_fold(%arg0: tensor<8x3xi32>,
-    %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
-  %c0 = arith.constant 0 : i32
-  %c8 = arith.constant 8 : index
-  %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
-  %pack = linalg.pack %cast
-    padding_value(%c0 : i32)
-    inner_dims_pos = [0, 1]
-    inner_tiles = [%c8, 1]
-    into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
-  return %pack : tensor<?x3x?x1xi32>
-}



More information about the Mlir-commits mailing list