[Mlir-commits] [mlir] [mlir] disable folding collapse expand to cast (PR #179209)
llvmlistbot at llvm.org
Mon Feb 2 03:16:21 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: ofri frishman (ofri-frishman)
<details>
<summary>Changes</summary>
Folding expand(collapse(src)) into cast(src) is currently allowed whenever the source and result types are cast-compatible but not equal. When the source has dynamic dimensions, this can produce a cast in which a dimension changes between dynamic and static even though the dynamic size is not guaranteed to equal the static size.
This patch blocks the fold when the source has any dynamic dimension, to preserve correctness.
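To make the problem concrete, here is a minimal C++ sketch that models only shapes, not MLIR types. `kDynamic`, `Shape`, `shapesCastCompatible`, `oldGuardAllowsFold`, `newGuardAllowsFold`, and `checkNewTestCase` are hypothetical names. The real pattern calls `CastOpTy::areCastCompatible` and `hasStaticShape()` on the memref types, and real cast compatibility also covers element type and layout, which this sketch ignores. The shapes come from the new test case.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sentinel for a dynamic ('?') dimension in this sketch.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Shape-only analogue of cast compatibility: same rank, and each pair of
// dims is either equal or has at least one dynamic side.
bool shapesCastCompatible(const Shape &a, const Shape &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != b[i] && a[i] != kDynamic && b[i] != kDynamic)
      return false;
  return true;
}

bool hasStaticShape(const Shape &s) {
  for (int64_t d : s)
    if (d == kDynamic)
      return false;
  return true;
}

// Guard before the patch: fold whenever the types are cast-compatible.
bool oldGuardAllowsFold(const Shape &src, const Shape &res) {
  return shapesCastCompatible(src, res);
}

// Guard after the patch: additionally require a fully static source.
bool newGuardAllowsFold(const Shape &src, const Shape &res) {
  return hasStaticShape(src) && shapesCastCompatible(src, res);
}

// Walks through the new test case; the asserts record the expected results.
void checkNewTestCase() {
  const Shape src = {1, kDynamic, 1, 32}; // memref<1x?x1x32xsi8, ...>
  const Shape res = {1, 1, kDynamic, 32}; // memref<1x1x?x32xsi8, ...>
  assert(oldGuardAllowsFold(src, res));   // cast-compatible, so it folded...
  assert(!newGuardAllowsFold(src, res));  // ...but now the fold is blocked.

  // Suppose the source '?' is 4 at runtime: collapse [[0], [1, 2], [3]]
  // yields 1x4x32 and expand [[0, 1], [2], [3]] yields 1x1x4x32. A cast
  // would equate source dim 1 (size 4) with result dim 1 (size 1).
  const Shape srcRuntime = {1, 4, 1, 32};
  const Shape resRuntime = {1, 1, 4, 32};
  assert(srcRuntime != resRuntime);
}
```

Calling `checkNewTestCase()` in a build without `NDEBUG` exercises the asserts. The shapes are cast-compatible, so the old guard would fold, while the new guard refuses because the source is not fully static. The last check shows why the fold is unsound: a cast compares dimensions by position, so it equates source dim 1 (size 4 here) with result dim 1 (size 1).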
In the future, it may be possible to re-enable the fold in some cases where not all source dimensions are static, for example when:
1) the expand and collapse only reshape static dimensions, or
2) the expand and collapse on the dynamic dimensions fold to a no-op.
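Condition 1) could be checked roughly as follows. This is a hypothetical sketch, not part of the patch: `reshapeTouchesOnlyStaticDims` and `case1Holds` are illustrative names, and reassociation groups are modeled as plain vectors (loosely like the reassociation indices of `memref.collapse_shape` and `memref.expand_shape`). Whether condition 1) alone is sufficient for a correct fold is not established here, and condition 2) is not modeled.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sentinel for a dynamic ('?') dimension in this sketch.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;
// For each collapsed dim, the expanded dims it covers, e.g. {{0}, {1, 2}, {3}}.
using Reassociation = std::vector<std::vector<int64_t>>;

// True if every group that actually merges or splits dims (size > 1)
// covers only static dims of `expandedShape`.
bool reshapeTouchesOnlyStaticDims(const Shape &expandedShape,
                                  const Reassociation &groups) {
  for (const auto &group : groups) {
    if (group.size() < 2)
      continue; // A singleton group passes its dim through unchanged.
    for (int64_t dim : group)
      if (expandedShape[dim] == kDynamic)
        return false;
  }
  return true;
}

// Condition 1): the collapse (groups over source dims) and the expand
// (groups over result dims) both reshape only static dims.
bool case1Holds(const Shape &src, const Reassociation &collapse,
                const Shape &res, const Reassociation &expand) {
  return reshapeTouchesOnlyStaticDims(src, collapse) &&
         reshapeTouchesOnlyStaticDims(res, expand);
}
```

For the new test case, `case1Holds` returns false because the collapse group `{1, 2}` contains the dynamic source dim 1, so under this condition the fold would stay disabled.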
---
Full diff: https://github.com/llvm/llvm-project/pull/179209.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+2-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 64c125024d906..338804af6b1b4 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -370,7 +370,8 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getResult().getType())) {
- if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ if (srcType.hasStaticShape() &&
+ CastOpTy::areCastCompatible(srcType, resultType)) {
rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
collapseOp.getSrc());
return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 17afd9a15b60d..3cfea1e8cd961 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1385,6 +1385,21 @@ func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided
// -----
+// CHECK-LABEL: func @expand_collapse_dynamic_do_not_fold_to_cast(
+// CHECK-NOT: memref.cast
+
+func.func @expand_collapse_dynamic_do_not_fold_to_cast(%m: memref<1x?x1x32xsi8, strided<[?, 32, 32, 1]>>, %dyn_size: index)
+ -> (memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0], [1, 2], [3]]
+ : memref<1x?x1x32xsi8, strided<[?, 32, 32, 1]>> into memref<1x?x32xsi8, strided<[?, 32, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, %dyn_size, 32]
+ : memref<1x?x32xsi8, strided<[?, 32, 1]>> into memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>
+ return %1 : memref<1x1x?x32xsi8, strided<[?, ?, 32, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
``````````
</details>
https://github.com/llvm/llvm-project/pull/179209