[Mlir-commits] [mlir] [mlir] disable folding collapse expand to cast (PR #179209)
ofri frishman
llvmlistbot at llvm.org
Mon Feb 2 03:15:50 PST 2026
https://github.com/ofri-frishman created https://github.com/llvm/llvm-project/pull/179209
Folding expand(collapse(src)) to cast(src) is supported in cases where the source and result types are cast compatible but not equal. When the source has dynamic dimensions, this allows the cast even though some dimensions are cast between static and dynamic sizes and the dynamic size is not guaranteed to equal the static size.
This patch blocks the folding when the source has dynamic dimensions, to preserve correctness.
In the future, some cases could be folded even when not all dimensions of the source are static, for example when:
1) the expand and collapse only involve static (non-dynamic) dimensions
2) the expand and collapse on dynamic dimensions fold to a no-op
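To make the failure concrete, here is a minimal C++ sketch that models shapes as plain vectors, with a hypothetical `kDynamic` sentinel standing in for `?`. The `areCastCompatible` function is a simplified stand-in for the cast-compatibility rule (equal rank; each pair of sizes equal or at least one dynamic) that ignores strided layouts, not the MLIR API, and the hypothetical `expandOfCollapse` hard-codes the reassociations from the test case added by this patch. It shows that the test's static source and result types pass the compatibility check while the runtime shapes disagree.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a shaped type: one entry per dimension, with
// kDynamic standing in for a `?` size. Layouts/strides are ignored.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Simplified cast compatibility: same rank, and every pair of sizes is
// either equal or has at least one dynamic side. Only static types are
// inspected, so nothing here knows the runtime value behind a `?`.
bool areCastCompatible(const Shape &a, const Shape &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return false;
  return true;
}

// Runtime shape of expand(collapse(src)) for the test case:
//   collapse [[0], [1, 2], [3]]  : 1 x d1 x d2 x 32 -> 1 x (d1 * d2) x 32
//   expand   [[0, 1], [2], [3]]  : output_shape [1, 1, %dyn_size, 32]
// so the result is 1 x 1 x (d1 * d2) x 32 and %dyn_size must be d1 * d2.
Shape expandOfCollapse(const Shape &srcRuntime) {
  assert(srcRuntime.size() == 4 && "expects a rank-4 source");
  int64_t merged = srcRuntime[1] * srcRuntime[2];
  return {1, 1, merged, srcRuntime[3]};
}
```

For example, with the source `?` bound to 4 at runtime, `expandOfCollapse({1, 4, 1, 32})` yields `{1, 1, 4, 32}`, while a cast would leave the buffer with shape `{1, 4, 1, 32}`. The static types `{1, kDynamic, 1, 32}` and `{1, 1, kDynamic, 32}` are cast compatible, but those two concrete shapes are not.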
From 03102065e6911480a220d278fd25c21c7b2e904b Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Mon, 2 Feb 2026 12:06:53 +0200
Subject: [PATCH] [mlir] disable folding collapse expand to cast
Folding expand(collapse(src)) to cast(src) is supported in cases
where the source and result are cast compatible but not equal.
When the source has dynamic dimensions, this leads to cases where
the cast is enabled even though certain dimensions are cast between
static and dynamic sizes when the dynamic size is not assured to be
equal to the static size.
This patch blocks the folding when the source has dynamic
dimensions, to preserve correctness.
In the future, some cases could be folded even when not all
dimensions of the source are static, for example when:
1) the expand and collapse only involve static (non-dynamic) dims
2) the expand and collapse on dynamic dims fold to a no-op
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 3 ++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 15 +++++++++++++++
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..338804af6b1b4 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -370,7 +370,8 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
         hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
         hasNonIdentityLayout(collapseOp.getResult().getType())) {
-      if (CastOpTy::areCastCompatible(srcType, resultType)) {
+      if (srcType.hasStaticShape() &&
+          CastOpTy::areCastCompatible(srcType, resultType)) {
         rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
                                               collapseOp.getSrc());
         return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 17afd9a15b60d..3cfea1e8cd961 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1385,6 +1385,21 @@ func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided
// -----
+// CHECK-LABEL: func @expand_collapse_dynamic_do_not_fold_to_cast(
+// CHECK-NOT: memref.cast
+
+func.func @expand_collapse_dynamic_do_not_fold_to_cast(%m: memref<1x?x1x32xsi8, strided<[?, 32, 32, 1]>>, %dyn_size: index)
+    -> (memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>)
+    {
+  %0 = memref.collapse_shape %m [[0], [1, 2], [3]]
+      : memref<1x?x1x32xsi8, strided<[?, 32, 32, 1]>> into memref<1x?x32xsi8, strided<[?, 32, 1]>>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, %dyn_size, 32]
+      : memref<1x?x32xsi8, strided<[?, 32, 1]>> into memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>
+  return %1 : memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
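The new guard in `ComposeExpandOfCollapseOp` can be sketched in plain C++ under simplifying assumptions: shapes are plain vectors with a hypothetical `kDynamic` sentinel for `?`, strided layouts are ignored, and `hasStaticShape` / `areCastCompatible` are hypothetical free functions modeled on `ShapedType::hasStaticShape` and `CastOpTy::areCastCompatible`, not the MLIR API. The hypothetical `foldsToCast` combines them the way the patched condition does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical shape model; kDynamic stands in for a `?` dimension.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Modeled on ShapedType::hasStaticShape: true when no dimension is dynamic.
bool hasStaticShape(const Shape &s) {
  for (int64_t d : s)
    if (d == kDynamic)
      return false;
  return true;
}

// Simplified cast compatibility: same rank, and each pair of sizes is
// equal or has at least one dynamic side.
bool areCastCompatible(const Shape &a, const Shape &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return false;
  return true;
}

// Mirrors the patched condition: only fold expand(collapse(src)) into a
// cast when the source shape is fully static and the types are compatible.
bool foldsToCast(const Shape &srcType, const Shape &resultType) {
  return hasStaticShape(srcType) && areCastCompatible(srcType, resultType);
}
```

With these definitions, the types from `@expand_collapse_dynamic_do_not_fold_to_cast` (`{1, kDynamic, 1, 32}` to `{1, 1, kDynamic, 32}`) still pass `areCastCompatible` but fail `foldsToCast`, while a fully static source such as `{1, 4, 1, 32}` with a result of `{1, 4, 1, kDynamic}` passes the guard. The check is deliberately conservative: it also rejects dynamic-source cases that would be safe, which are the cases the description suggests enabling later.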