[Mlir-commits] [mlir] [mlir] Compose expand of collapse to cast (PR #172864)
Maya Amrami
llvmlistbot at llvm.org
Thu Dec 18 06:51:46 PST 2025
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/172864
In some cases, an expand(collapse(x)) pair cannot be folded into x, because its result type differs from the type of x.
In that case, it is folded into a cast instead.
This requires a change in memref::CastOp::areCastCompatible: the strides of a dimension of size 1 are no longer required to match.
>From 3f8a66da78c86ad14d4d7708b59d532fd0e381c6 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Thu, 18 Dec 2025 16:43:08 +0200
Subject: [PATCH] [mlir] Compose expand of collapse to cast
In some cases, an expand(collapse(x)) pair cannot be folded
into x, because its result type differs from the type of x.
In that case, it is folded into a cast instead.
This requires a change in memref::CastOp::areCastCompatible:
the strides of a dimension of size 1 are no longer required to match.
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 12 +++++++++---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 9 ++++++---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++++
4 files changed, 35 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..dea28268b932f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
- hasNonIdentityLayout(collapseOp.getResult().getType()))
+ hasNonIdentityLayout(collapseOp.getResult().getType())) {
+ if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+ collapseOp.getSrc());
+ return success();
+ }
return failure();
+ }
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..40d93d7308545 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -753,9 +753,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
};
if (!checkCompatible(aOffset, bOffset))
return false;
- for (const auto &aStride : enumerate(aStrides))
- if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ for (const auto &[index, aStride] : enumerate(aStrides)) {
+ if (aT.getDimSize(index) == 1)
+ continue;
+ if (!checkCompatible(aStride, bStrides[index]))
return false;
+ }
}
if (aT.getMemorySpace() != bT.getMemorySpace())
return false;
@@ -2508,7 +2511,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..c15d4ac29433a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2251,7 +2251,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..24f604099b799 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,6 +1193,25 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
+// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK: return %[[casted]]
+
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
+ -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+ into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
+ : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
More information about the Mlir-commits
mailing list