[Mlir-commits] [mlir] [mlir] Compose expand of collapse to cast (PR #172864)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 18 06:52:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maya Amrami (amrami)
<details>
<summary>Changes</summary>
In some cases an expand(collapse(x)) pair cannot be folded into x, because its result type differs from the type of x.
In that case, it is folded into a cast instead.
This requires a change in memref::CastOp::areCastCompatible: the strides of a dimension of size 1 are now allowed to differ between the source and result types.
---
Full diff: https://github.com/llvm/llvm-project/pull/172864.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+9-3)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-3)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+19)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..dea28268b932f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
- hasNonIdentityLayout(collapseOp.getResult().getType()))
+ hasNonIdentityLayout(collapseOp.getResult().getType())) {
+ if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+ collapseOp.getSrc());
+ return success();
+ }
return failure();
+ }
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..40d93d7308545 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -753,9 +753,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
};
if (!checkCompatible(aOffset, bOffset))
return false;
- for (const auto &aStride : enumerate(aStrides))
- if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ for (const auto &[index, aStride] : enumerate(aStrides)) {
+ if (aT.getDimSize(index) == 1)
+ continue;
+ if (!checkCompatible(aStride, bStrides[index]))
return false;
+ }
}
if (aT.getMemorySpace() != bT.getMemorySpace())
return false;
@@ -2508,7 +2511,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..c15d4ac29433a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2251,7 +2251,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..24f604099b799 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,6 +1193,25 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
+// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+// CHECK-SAME: %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+// CHECK: return %[[casted]]
+
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
+ -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+ into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
+ : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+ into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172864
More information about the Mlir-commits
mailing list