[Mlir-commits] [mlir] [mlir] Compose expand of collapse to cast (PR #172864)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 18 06:52:20 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maya Amrami (amrami)

<details>
<summary>Changes</summary>

In some cases an expand(collapse(x)) pair cannot be folded into x, because its result type differs from the type of x.
In that case, the pair is folded into a cast instead.

This requires a change to memref::CastOp::areCastCompatible: the strides of a dimension of size 1 are no longer compared, so they may differ between the source and result types.

---
Full diff: https://github.com/llvm/llvm-project/pull/172864.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+9-3) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-3) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+19) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 6d4ea5b5136de..dea28268b932f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -355,7 +355,7 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   }
 };
 
-template <typename ExpandOpTy, typename CollapseOpTy>
+template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
   using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandOpTy expandOp,
@@ -369,8 +369,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
         hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
-        hasNonIdentityLayout(collapseOp.getResult().getType()))
+        hasNonIdentityLayout(collapseOp.getResult().getType())) {
+      if (CastOpTy::areCastCompatible(srcType, resultType)) {
+        rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
+                                                    collapseOp.getSrc());
+        return success();
+      }
       return failure();
+    }
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
@@ -490,7 +496,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..40d93d7308545 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -753,9 +753,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
       };
       if (!checkCompatible(aOffset, bOffset))
         return false;
-      for (const auto &aStride : enumerate(aStrides))
-        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+      for (const auto &[index, aStride] : enumerate(aStrides)) {
+        if (aT.getDimSize(index) == 1)
+          continue;
+        if (!checkCompatible(aStride, bStrides[index]))
           return false;
+      }
     }
     if (aT.getMemorySpace() != bT.getMemorySpace())
       return false;
@@ -2508,7 +2511,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..c15d4ac29433a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2251,7 +2251,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
       ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>>(context);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 60311306b984d..24f604099b799 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1193,6 +1193,25 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0
 
 // -----
 
+// CHECK-LABEL: func @expand_collapse_fold_to_cast(
+//  CHECK-SAME:     %[[m:.*]]: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+//       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+//       CHECK:   return %[[casted]]
+
+func.func @expand_collapse_fold_to_cast(%m: memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>)
+    -> (memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>)
+  {
+  %0 = memref.collapse_shape %m [[0, 1], [2], [3]]
+      : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+        into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3]] output_shape [1, 1, 384, 384]
+      : memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+        into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+  return %1 : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_trivial_subviews(
 //  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
 //       CHECK:   %[[subview:.*]] = memref.subview %[[m]][5]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172864


More information about the Mlir-commits mailing list