[Mlir-commits] [mlir] [mlir][linalg] Bail out tensor.cast pack/unpack fold on unprovable tile sizes (PR #188000)
Hocky Yudhiono
llvmlistbot at llvm.org
Mon Mar 23 02:45:01 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/188000
>From ab62f02f07373b19866a772ec2985f7b74427aa8 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 23 Mar 2026 17:44:17 +0800
Subject: [PATCH] [mlir][linalg] Bail out tensor.cast pack/unpack fold on
unprovable tile sizes
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 41 +++--
...canonicalize-dynamic-pack-unpack-tile.mlir | 149 ++++++++++++++++++
2 files changed, 176 insertions(+), 14 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ad2909f656eea..95aeb821c51d0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5000,8 +5000,10 @@ template SmallVector<int64_t>
// * a dim from newPackedTy is static, and
// * the corresponding size from mixedTiles is still dynamic.
// Otherwise, the original tile size is preserved.
+// Returns failure when a dynamic tile cannot be proven to match the static
+// packed dim.
// Note - packed-type-dim and mixed-tile-size should always match!
-static SmallVector<OpFoldResult>
+static FailureOr<SmallVector<OpFoldResult>>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
SmallVector<OpFoldResult> mixedTiles) {
SmallVector<OpFoldResult> newMixedTileSizes;
@@ -5015,17 +5017,21 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
continue;
}
- // If the current result dim is static, update the dynamic mixed-size
- // (provided the original value is dynamic).
+ // If the current result dim is static, update the dynamic mixed-size only
+ // when the original dynamic value is a known constant matching `shape`.
+ // Otherwise, bail out and let the fold fail conservatively.
OpFoldResult tile = std::get<1>(it);
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
// Already a constant
newMixedTileSizes.push_back(tile);
} else {
- assert(getConstantIntValue(tile).value() == shape &&
- "tile size and dim size don't match!");
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ std::optional<int64_t> constTile = getConstantIntValue(tile);
+ if (constTile.has_value() && constTile.value() == shape) {
+ newMixedTileSizes.push_back(
+ rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
+ } else {
+ return failure();
+ }
}
}
@@ -5995,8 +6001,11 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes =
+ FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
+ if (failed(newMixedTileSizes))
+ return rewriter.notifyMatchFailure(
+ op, "unable to prove dynamic tile sizes after folding tensor.cast");
// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -6004,7 +6013,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// to preserve. Implement a better abstraction.
PackOp newOp =
PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
- op.getInnerDimsPos(), newMixedTileSizes,
+ op.getInnerDimsPos(), newMixedTileSizes.value(),
op.getPaddingValue(), op.getOuterDimsPerm());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
@@ -6476,16 +6485,20 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
Value sourceTensor = newOperands[0];
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
- rewriter, sourceTensor.getType(), op.getMixedTiles());
+ FailureOr<SmallVector<OpFoldResult>> newMixedTileSizes =
+ getNewMixedTileSizes(rewriter, sourceTensor.getType(), op.getMixedTiles());
+ if (failed(newMixedTileSizes))
+ return rewriter.notifyMatchFailure(
+ op, "unable to prove dynamic tile sizes after folding tensor.cast");
// Clone op.
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
- UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
- newOperands[1], op.getInnerDimsPos(),
- newMixedTileSizes, op.getOuterDimsPerm());
+ UnPackOp newOp =
+ UnPackOp::create(rewriter, op.getLoc(), sourceTensor, newOperands[1],
+ op.getInnerDimsPos(), newMixedTileSizes.value(),
+ op.getOuterDimsPerm());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Replace op.
diff --git a/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir b/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
new file mode 100644
index 0000000000000..eec3e3acc93fb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/canonicalize-dynamic-pack-unpack-tile.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s -split-input-file -inline -canonicalize="test-convergence" | FileCheck %s

// CHECK-LABEL: func.func @dynamic_tile_arg_no_fold
// CHECK-SAME: %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
// CHECK-DAG: %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
// CHECK-SAME: inner_dims_pos = [0, 1]
// CHECK-SAME: inner_tiles = [%[[TILE]], 1]
// CHECK-SAME: into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
// CHECK: return %[[UNPACK]] : tensor<7x3xi32>
module {
+ func.func @dynamic_tile_arg_no_fold(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+ %0 = tensor.empty() : tensor<7x3xi32>
+ %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+ %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+ return %unpack : tensor<7x3xi32>
+ }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @dynamic_tile_from_inlined_mismatch_no_fold
+// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG: %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [%[[C256]], 1]
+// CHECK-SAME: into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK: return %[[UNPACK]] : tensor<7x3xi32>
+module {
+ func.func @get_tile() -> index {
+ %c256 = arith.constant 256 : index
+ return %c256 : index
+ }
+ func.func @dynamic_tile_from_inlined_mismatch_no_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+ %0 = call @get_tile() : () -> index
+ %1 = tensor.empty() : tensor<7x3xi32>
+ %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+ %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+ return %unpack : tensor<7x3xi32>
+ }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @constant_tile_from_inlined_match_folds
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-NOT: tensor.cast
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
+// CHECK: return %[[UNPACK]] : tensor<7x3xi32>
+module {
+ func.func @get_tile() -> index {
+ %c8 = arith.constant 8 : index
+ return %c8 : index
+ }
+ func.func @constant_tile_from_inlined_match_folds(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+ %0 = call @get_tile() : () -> index
+ %1 = tensor.empty() : tensor<7x3xi32>
+ %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+ %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%0, 1] into %1 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+ return %unpack : tensor<7x3xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_arg
+// CHECK-SAME: %[[SRC:.+]]: tensor<8x3xi32>, %[[TILE:.+]]: index, %[[DEST:.+]]: tensor<?x3x?x1xi32>
+// CHECK: %[[PACK:.+]] = linalg.pack
+// CHECK: padding_value
+// CHECK: inner_dims_pos = [0, 1]
+// CHECK: inner_tiles = [%[[TILE]], 1]
+// CHECK: into %[[DEST]] : tensor
+// CHECK: return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+ func.func @pack_dynamic_tile_arg(%arg0: tensor<8x3xi32>, %arg1: index,
+ %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+ %c0 = arith.constant 0 : i32
+ %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+ %pack = linalg.pack %cast
+ padding_value(%c0 : i32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%arg1, 1]
+ into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+ return %pack : tensor<?x3x?x1xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_mismatch
+// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
+// CHECK: %[[PACK:.+]] = linalg.pack
+// CHECK: padding_value
+// CHECK: inner_dims_pos = [0, 1]
+// CHECK: inner_tiles = [%[[C256]], 1]
+// CHECK: into %{{.+}} : tensor
+// CHECK: return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+ func.func @pack_get_tile() -> index {
+ %c256 = arith.constant 256 : index
+ return %c256 : index
+ }
+ func.func @pack_dynamic_tile_from_inlined_mismatch(%arg0: tensor<8x3xi32>,
+ %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+ %c0 = arith.constant 0 : i32
+ %0 = call @pack_get_tile() : () -> index
+ %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+ %pack = linalg.pack %cast
+ padding_value(%c0 : i32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%0, 1]
+ into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+ return %pack : tensor<?x3x?x1xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_dynamic_tile_from_inlined_match_fold
+// CHECK: %[[PACK:.+]] = linalg.pack
+// CHECK: padding_value
+// CHECK: inner_dims_pos = [0, 1]
+// CHECK: inner_tiles = [%{{.+}}, 1]
+// CHECK: into %{{.+}} : tensor
+// CHECK: return %[[PACK]] : tensor<?x3x?x1xi32>
+module {
+ func.func @pack_get_tile() -> index {
+ %c8 = arith.constant 8 : index
+ return %c8 : index
+ }
+ func.func @pack_dynamic_tile_from_inlined_match_fold(%arg0: tensor<8x3xi32>,
+ %dest: tensor<?x3x?x1xi32>) -> tensor<?x3x?x1xi32> {
+ %c0 = arith.constant 0 : i32
+ %0 = call @pack_get_tile() : () -> index
+ %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor<?x?xi32>
+ %pack = linalg.pack %cast
+ padding_value(%c0 : i32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [%0, 1]
+ into %dest : tensor<?x?xi32> -> tensor<?x3x?x1xi32>
+ return %pack : tensor<?x3x?x1xi32>
+ }
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list