[Mlir-commits] [mlir] [mlir] Compose expand of collapse to cast (PR #172864)
Maya Amrami
llvmlistbot at llvm.org
Wed Dec 24 00:16:52 PST 2025
https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/172864
>From dde4eaeb4a5df41d0fb92ecae0ac827c83cec0a4 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Thu, 18 Dec 2025 16:43:08 +0200
Subject: [PATCH 1/3] [mlir] Compose expand of collapse to cast
In some cases an expand(collapse(x)) pair cannot be folded
into x, since its result has a different type than x.
In that case, it is folded into a cast of x instead.
This requires a change in memref::CastOp::areCastCompatible,
which now allows the stride of a size-1 dimension to differ.
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 12 +++++++++---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++++++---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++++
4 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..64c125024d906 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
- hasNonIdentityLayout(collapseOp.getResult().getType()))
+ hasNonIdentityLayout(collapseOp.getResult().getType())) {
+ if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+ collapseOp.getSrc());
+ return success();
+ }
return failure();
+ }
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..72df34f4481da 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -753,9 +753,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
};
if (!checkCompatible(aOffset, bOffset))
return false;
- for (const auto &aStride : enumerate(aStrides))
- if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ for (const auto &[index, aStride] : enumerate(aStrides)) {
+ if (aT.getDimSize(index) == 1)
+ continue;
+ if (!checkCompatible(aStride, bStrides[index]))
return false;
+ }
}
if (aT.getMemorySpace() != bT.getMemorySpace())
return false;
@@ -2508,7 +2511,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>>(
+ context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..c15d4ac29433a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2251,7 +2251,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..24f604099b799 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,6 +1193,25 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
+// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK: return %[[casted]]
+
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
+ -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+ into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
+ : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
>From b45ba167a37ebcca40b3785a2783c2a0d2f5b8c8 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Sun, 21 Dec 2025 12:16:04 +0200
Subject: [PATCH 2/3] CR
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1 +
mlir/test/Dialect/MemRef/canonicalize.mlir | 37 ++++++++++++++++------
2 files changed, 28 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 72df34f4481da..f271a8db39e29 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -748,6 +748,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// source memref is static and the value in the target memref is the
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
+ // Note that for dimensions of size 1, the stride can differ.
auto checkCompatible = [](int64_t a, int64_t b) {
return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 24f604099b799..0878842bb2b05 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1194,20 +1194,37 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
// CHECK-LABEL: func @expand_collapse_fold_to_cast(
-// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
-// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK-SAME: %[[m:.*]]: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
// CHECK: return %[[casted]]
-func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
- -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+func.func @expand_collapse_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>)
{
%0 = memref.collapse_shape %m [[0, 1], [2], [3]]
- : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
- into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
- %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
- : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
- into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
- return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[768, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 3, 2, 384]
+ : memref<3x2x384xui8, strided<[768, 384, 1]>>
+ into memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+ return %1 : memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_do_not_fold_to_cast(
+// CHECK-NOT: memref.cast
+
+func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[768, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+ : memref<3x2x384xui8, strided<[768, 384, 1]>>
+ into memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+ return %1 : memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
}
// -----
>From f181026b19e05313409d6acc4f9fee8ddb065b88 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Wed, 24 Dec 2025 10:16:25 +0200
Subject: [PATCH 3/3] CR 2 - Adding a lit test where the mismatch in strides is
in an internal dim
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 0878842bb2b05..3e1f9965f2495 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,12 +1193,30 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
-// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+// CHECK-LABEL: func @expand_collapse_fold_to_internal_stride_cast(
+// CHECK-SAME: %[[m:.*]]: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+
+func.func @expand_collapse_fold_to_internal_stride_cast(%m: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+ : memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+ into memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_fold_to_outermost_stride_cast(
// CHECK-SAME: %[[m:.*]]: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
// CHECK: return %[[casted]]
-func.func @expand_collapse_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+func.func @expand_collapse_fold_to_outermost_stride_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
-> (memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>)
{
%0 = memref.collapse_shape %m [[0, 1], [2], [3]]
More information about the Mlir-commits
mailing list