[Mlir-commits] [mlir] 4ccf926 - [mlir] Compose expand of collapse to cast (#172864)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 01:57:28 PST 2026
Author: Maya Amrami
Date: 2026-01-13T11:57:23+02:00
New Revision: 4ccf926e7f075724e3206f92623c8e00d1d34417
URL: https://github.com/llvm/llvm-project/commit/4ccf926e7f075724e3206f92623c8e00d1d34417
DIFF: https://github.com/llvm/llvm-project/commit/4ccf926e7f075724e3206f92623c8e00d1d34417.diff
LOG: [mlir] Compose expand of collapse to cast (#172864)
In some cases `y = expand(collapse(x))` cannot be folded into x, since x
and y have different types.
In that case, we check if the two types are cast compatible.
If they are, it means the two types have compatible shape and layout and
y can be folded into cast(x).
This causes a change in memref::CastOp::areCastCompatible, where now a
dim of size 1 may have different strides.
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..64c125024d906 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
};
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
- hasNonIdentityLayout(collapseOp.getResult().getType()))
+ hasNonIdentityLayout(collapseOp.getResult().getType())) {
+ if (CastOpTy::areCastCompatible(srcType, resultType)) {
+ rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+ collapseOp.getSrc());
+ return success();
+ }
return failure();
+ }
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a612475edf3b9..13310c59f9682 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -755,14 +755,18 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// source memref is static and the value in the target memref is the
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
+ // Note that for dimensions of size 1, the stride can differ.
auto checkCompatible = [](int64_t a, int64_t b) {
return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
if (!checkCompatible(aOffset, bOffset))
return false;
- for (const auto &aStride : enumerate(aStrides))
- if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+ for (const auto &[index, aStride] : enumerate(aStrides)) {
+ if (aT.getDimSize(index) == 1)
+ continue;
+ if (!checkCompatible(aStride, bStrides[index]))
return false;
+ }
}
if (aT.getMemorySpace() != bT.getMemorySpace())
return false;
@@ -2580,7 +2584,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ExpandShapeOpMemRefCastFolder>(context);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 11824a0ac6f05..05db7d0dd33ee 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2255,7 +2255,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 132acfd9b1d48..122906037b952 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1331,6 +1331,60 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
// -----
+// CHECK-LABEL: func @expand_collapse_fold_to_internal_stride_cast(
+// CHECK-SAME: %[[m:.*]]: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+
+func.func @expand_collapse_fold_to_internal_stride_cast(%m: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+ : memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+ into memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+ return %1 : memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_fold_to_outermost_stride_cast(
+// CHECK-SAME: %[[m:.*]]: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+// CHECK: return %[[casted]]
+
+func.func @expand_collapse_fold_to_outermost_stride_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[768, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 3, 2, 384]
+ : memref<3x2x384xui8, strided<[768, 384, 1]>>
+ into memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+ return %1 : memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_do_not_fold_to_cast(
+// CHECK-NOT: memref.cast
+
+func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+ -> (memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>)
+ {
+ %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+ : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+ into memref<3x2x384xui8, strided<[768, 384, 1]>>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+ : memref<3x2x384xui8, strided<[768, 384, 1]>>
+ into memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+ return %1 : memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_subviews(
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
More information about the Mlir-commits
mailing list