[Mlir-commits] [mlir] 4ccf926 - [mlir] Compose expand of collapse to cast (#172864)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 13 01:57:28 PST 2026


Author: Maya Amrami
Date: 2026-01-13T11:57:23+02:00
New Revision: 4ccf926e7f075724e3206f92623c8e00d1d34417

URL: https://github.com/llvm/llvm-project/commit/4ccf926e7f075724e3206f92623c8e00d1d34417
DIFF: https://github.com/llvm/llvm-project/commit/4ccf926e7f075724e3206f92623c8e00d1d34417.diff

LOG: [mlir] Compose expand of collapse to cast (#172864)

In some cases `y = expand(collapse(x))` cannot be folded into x, since x
and y have different types.
In that case, we check if the two types are cast compatible.
If they are, it means the two types have compatible shape and layout and
y can be folded into cast(x).

This requires a change in memref::CastOp::areCastCompatible: the strides of
a dim of size 1 are now allowed to differ between the source and result types.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..64c125024d906 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   }
 };
 
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
   using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
         hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
-        hasNonIdentityLayout(collapseOp.getResult().getType()))
+        hasNonIdentityLayout(collapseOp.getResult().getType())) {
+      if (CastOpTy::areCastCompatible(srcType, resultType)) {
+        rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+                                              collapseOp.getSrc());
+        return success();
+      }
       return failure();
+    }
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a612475edf3b9..13310c59f9682 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -755,14 +755,18 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
       // source memref is static and the value in the target memref is the
       // same. They are also compatible if either one is dynamic (see
       // description of MemRefCastOp for details).
+      // Note that for dimensions of size 1, the stride can differ.
       auto checkCompatible = [](int64_t a, int64_t b) {
         return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
       };
       if (!checkCompatible(aOffset, bOffset))
         return false;
-      for (const auto &aStride : enumerate(aStrides))
-        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+      for (const auto &[index, aStride] : enumerate(aStrides)) {
+        if (aT.getDimSize(index) == 1)
+          continue;
+        if (!checkCompatible(aStride, bStrides[index]))
           return false;
+      }
     }
     if (aT.getMemorySpace() != bT.getMemorySpace())
       return false;
@@ -2580,7 +2584,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
       ExpandShapeOpMemRefCastFolder>(context);
 }
 

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 11824a0ac6f05..05db7d0dd33ee 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2255,7 +2255,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
       ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>>(context);

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 132acfd9b1d48..122906037b952 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1331,6 +1331,60 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
 
 // -----
 
+// CHECK-LABEL: func @expand_collapse_fold_to_internal_stride_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+
+func.func @expand_collapse_fold_to_internal_stride_cast(%m: memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+    -> (memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+  {
+  %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+      : memref<3x1x2x384xui8, strided<[1179648, 768, 384, 1]>>
+        into memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+      : memref<3x2x384xui8, strided<[1179648, 384, 1]>>
+        into memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+  return %1 : memref<3x1x2x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_fold_to_outermost_stride_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>> to memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+//       CHECK:   return %[[casted]]
+
+func.func @expand_collapse_fold_to_outermost_stride_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+    -> (memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>)
+  {
+  %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+      : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+        into memref<3x2x384xui8, strided<[768, 384, 1]>>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 3, 2, 384]
+      : memref<3x2x384xui8, strided<[768, 384, 1]>>
+        into memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+  return %1 : memref<1x3x2x384xui8, strided<[2304, 768, 384, 1]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @expand_collapse_do_not_fold_to_cast(
+//   CHECK-NOT:   memref.cast
+
+func.func @expand_collapse_do_not_fold_to_cast(%m: memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>)
+    -> (memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>)
+  {
+  %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+      : memref<1x3x2x384xui8, strided<[1179648, 768, 384, 1]>>
+        into memref<3x2x384xui8, strided<[768, 384, 1]>>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [3, 1, 2, 384]
+      : memref<3x2x384xui8, strided<[768, 384, 1]>>
+        into memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+  return %1 : memref<3x1x2x384xui8, strided<[768, 768, 384, 1]>>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_trivial_subviews(
 //  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
 //       CHECK:   %[[subview:.*]] = memref.subview %[[m]][5]


        


More information about the Mlir-commits mailing list