[Mlir-commits] [mlir] [mlir] Compose expand of collapse to cast (PR #172864)
Maya Amrami
llvmlistbot at llvm.org
Sun Dec 21 02:16:17 PST 2025
https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/172864
>From dde4eaeb4a5df41d0fb92ecae0ac827c83cec0a4 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Thu, 18 Dec 2025 16:43:08 +0200
Subject: [PATCH 1/2] [mlir] Compose expand of collapse to cast
In some cases an expand(collapse(x)) pair cannot be folded
into x, since its result has a different type than x.
In that case, it is folded into a cast instead, when the two types are cast-compatible.
This requires a change in memref::CastOp::areCastCompatible:
the stride of a dimension of size 1 may now differ between the source and result types.
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 12 +++++++++---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++++++---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++++
4 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..64c125024d906 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
- hasNonIdentityLayout(collapseOp.getResult().getType()))
+ hasNonIdentityLayout(collapseOp.getResult().getType())) {
+ if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+ collapseOp.getSrc());
+ return success();
+ }
return failure();
+ }
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..72df34f4481da 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -753,9 +753,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
};
if (!checkCompatible(aOffset, bOffset))
return false;
- for (const auto &aStride : enumerate(aStrides))
- if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ for (const auto &[index, aStride] : enumerate(aStrides)) {
+ if (aT.getDimSize(index) == 1)
+ continue;
+ if (!checkCompatible(aStride, bStrides[index]))
return false;
+ }
}
if (aT.getMemorySpace() != bT.getMemorySpace())
return false;
@@ -2508,7 +2511,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>>(
+ context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..c15d4ac29433a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2251,7 +2251,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..24f604099b799 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,6 +1193,25 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
+// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK: return %[[casted]]
+
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
+ -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+ into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
+ : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
>From 7ad26c22767320419ada0522388a3acec670b36d Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Sun, 21 Dec 2025 12:16:04 +0200
Subject: [PATCH 2/2] CR
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1 +
mlir/test/Dialect/MemRef/canonicalize.mlir | 37 ++++++++++++++++------
2 files changed, 28 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 72df34f4481da..f271a8db39e29 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -748,6 +748,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// source memref is static and the value in the target memref is the
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
+ // Note that for dimensions of size 1, the stride can differ.
auto checkCompatible = [](int64_t a, int64_t b) {
return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 24f604099b799..330c5743c3262 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1194,20 +1194,37 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
// CHECK-LABEL: func @expand_collapse_fold_to_cast(
-// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
-// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK-SAME: %[[m:.*]]: memref<1x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<1x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
// CHECK: return %[[casted]]
-func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
- -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<1x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>)
{
%0 = memref.collapse_shape %m [[0, 1], [2], [3]]
- : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
- into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
- %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
- : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
- into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
- return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ : memref<1x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<1x2x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 2, 384]
+ : memref<1x2x384xui8, strided<[1179648, 384, 1]>>
+ into memref<1x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<1x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_do_not_fold_to_cast(
+// CHECK-NOT: memref.cast
+
+func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[768, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+ : memref<3x2x384xui8, strided<[768, 384, 1]>>
+ into memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+ return %1 : memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
}
// -----
More information about the Mlir-commits
mailing list