[Mlir-commits] [mlir] [mlir][tensor] Fold EmptyOp & Collapse/ExpandShapeOp to EmptyOp (PR #175437)

Maya Amrami llvmlistbot at llvm.org
Sun Jan 11 06:22:12 PST 2026


https://github.com/amrami created https://github.com/llvm/llvm-project/pull/175437

None

>From 9e16276cc1b82625a3225f5efe3af306f2ee2083 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Sun, 11 Jan 2026 16:21:21 +0200
Subject: [PATCH] [mlir][tensor] Fold EmptyOp & Collapse/ExpandShapeOp to
 EmptyOp

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 23 +++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir |  8 +++---
 mlir/test/Dialect/Tensor/canonicalize.mlir | 32 +++++++++++++++++++---
 3 files changed, 54 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a0c7e40c20a46..d085e7cb72c5c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1244,12 +1244,41 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
   }
 };
 
+/// Replaces a `tensor.collapse_shape` / `tensor.expand_shape` (selected by
+/// `T`) whose source is a statically shaped `tensor.empty` with a new
+/// `tensor.empty` of the reshape's result type.
+template <typename T>
+struct FoldEmptyTensorWithCollapseExpandOp : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSrc().template getDefiningOp<EmptyOp>();
+    if (!producer)
+      return failure();
+    // Only statically shaped sources are handled; the replacement below is
+    // built from the result shape alone, with no dynamic size operands.
+    if (!producer.getType().hasStaticShape())
+      return failure();
+
+    // Forward the encoding too; otherwise the replacement's type would not
+    // match the type of the result being replaced.
+    auto resultType = cast<RankedTensorType>(op.getResultType());
+    rewriter.replaceOpWithNewOp<EmptyOp>(op, resultType.getShape(),
+                                         resultType.getElementType(),
+                                         resultType.getEncoding());
+    return success();
+  }
+};
+
 } // namespace
 
 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
-              ReplaceEmptyTensorStaticShapeDims>(context);
+              ReplaceEmptyTensorStaticShapeDims,
+              FoldEmptyTensorWithCollapseExpandOp<CollapseShapeOp>,
+              FoldEmptyTensorWithCollapseExpandOp<ExpandShapeOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f4020ede4854e..f2a708b1d747d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -389,9 +389,9 @@ func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
   %empty = tensor.empty() : tensor<1x2x3x4xf32>
-  // CHECK:      %[[COLLAPSE:.+]] = tensor.collapse_shape
+  // CHECK:      %[[COLLAPSED_EMPTY:.+]] = tensor.empty() : tensor<6x4xf32>
   // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
-  // CHECK-SAME:   outs(%[[COLLAPSE]] : tensor<6x4xf32>)
+  // CHECK-SAME:   outs(%[[COLLAPSED_EMPTY]] : tensor<6x4xf32>)
   %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
   %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
       : tensor<1x2x3x4xf32> into tensor<6x4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 95c5b8c91edf5..fe0ea58282149 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2218,6 +2218,45 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
 
 // -----
 
+func.func @fold_empty_tensor_with_collapse() -> tensor<12xf32> {
+  %0 = tensor.empty() : tensor<1x12xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x12xf32> into tensor<12xf32>
+  return %1 : tensor<12xf32>
+}
+
+// CHECK-LABEL: func @fold_empty_tensor_with_collapse()
+//       CHECK:   %[[T0:.+]] = tensor.empty() : tensor<12xf32>
+//   CHECK-NOT:   tensor.collapse_shape
+//       CHECK:   return %[[T0]] : tensor<12xf32>
+
+// -----
+
+func.func @fold_empty_tensor_with_expand() -> tensor<1x12xf32> {
+  %0 = tensor.empty() : tensor<12xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 12] : tensor<12xf32> into tensor<1x12xf32>
+  return %1 : tensor<1x12xf32>
+}
+
+// CHECK-LABEL: func @fold_empty_tensor_with_expand()
+//       CHECK:   %[[T0:.+]] = tensor.empty() : tensor<1x12xf32>
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   return %[[T0]] : tensor<1x12xf32>
+
+// -----
+
+// A dynamically shaped tensor.empty source is not folded.
+func.func @no_fold_empty_tensor_with_dynamic_collapse(%arg0 : index) -> tensor<?xf32> {
+  %0 = tensor.empty(%arg0) : tensor<?x12xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x12xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @no_fold_empty_tensor_with_dynamic_collapse(
+//       CHECK:   tensor.empty(
+//       CHECK:   tensor.collapse_shape
+
+// -----
+
 func.func private @some_use(%i : index, %j : index)
 
 // CHECK-LABEL: func @empty_tensor_canonicalize



More information about the Mlir-commits mailing list