[Mlir-commits] [mlir] [mlir][tensor] Fold EmptyOp & Collapse/ExpandShapeOp to EmptyOp (PR #175437)
Maya Amrami
llvmlistbot at llvm.org
Sun Jan 11 06:22:12 PST 2026
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/175437
None
>From 9e16276cc1b82625a3225f5efe3af306f2ee2083 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Sun, 11 Jan 2026 16:21:21 +0200
Subject: [PATCH] [mlir][tensor] Fold EmptyOp & Collapse/ExpandShapeOp to
EmptyOp
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 23 +++++++++++++++-
mlir/test/Dialect/Linalg/canonicalize.mlir | 8 +++---
mlir/test/Dialect/Tensor/canonicalize.mlir | 32 +++++++++++++++++++---
3 files changed, 54 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a0c7e40c20a46..d085e7cb72c5c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1244,12 +1244,33 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
}
};
+/// Folds a reshape (`tensor.collapse_shape` or `tensor.expand_shape`, selected
+/// by `T`) whose source is a statically shaped `tensor.empty` into a new
+/// `tensor.empty` of the reshape's result type.
+///
+/// Example:
+///   %0 = tensor.empty() : tensor<1x12xf32>
+///   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x12xf32> into tensor<12xf32>
+/// is rewritten to:
+///   %1 = tensor.empty() : tensor<12xf32>
+///
+/// Dynamically shaped producers are left untouched: the replacement is built
+/// from static sizes only, without any dynamic-size operands.
+template <typename T>
+struct FoldEmptyTensorWithCollapseExpandOp : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSrc().template getDefiningOp<EmptyOp>();
+    if (!producer)
+      return rewriter.notifyMatchFailure(op, "source is not a tensor.empty");
+    if (!producer.getType().hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "tensor.empty has dynamic shape");
+
+    RankedTensorType resultType = op.getResultType();
+    // The new tensor.empty gets no dynamic-size operands, so the reshape
+    // result itself must be fully static as well.
+    if (!resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "result type has dynamic shape");
+
+    // Carry the encoding over so the new value has exactly the type of the
+    // value it replaces; dropping it would leave users with a type mismatch.
+    rewriter.replaceOpWithNewOp<EmptyOp>(op, resultType.getShape(),
+                                         resultType.getElementType(),
+                                         resultType.getEncoding());
+    return success();
+  }
+};
+
} // namespace
+/// Registers the canonicalization patterns associated with `tensor.empty`.
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
- ReplaceEmptyTensorStaticShapeDims>(context);
+ ReplaceEmptyTensorStaticShapeDims,
+ // These two instantiations are rooted on the reshape op given as the
+ // template argument, not on EmptyOp; FoldEmptyTensorWithCastOp (rooted on
+ // CastOp) is registered here the same way.
+ FoldEmptyTensorWithCollapseExpandOp<CollapseShapeOp>,
+ FoldEmptyTensorWithCollapseExpandOp<ExpandShapeOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f4020ede4854e..f2a708b1d747d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -389,9 +389,9 @@ func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x2x3x4xf32>
- // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+ // CHECK: %[[COLLAPSED_EMPTY:.+]] = tensor.empty()
// CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
- // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>)
+ // CHECK-SAME: outs(%[[COLLAPSED_EMPTY]] : tensor<6x4xf32>)
%fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
: tensor<1x2x3x4xf32> into tensor<6x4xf32>
@@ -512,7 +512,7 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// -----
// CHECK-LABEL: func @no_fold_fill_like_memref
-// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: linalg.generic
func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
@@ -528,7 +528,7 @@ func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32)
// -----
// CHECK-LABEL: func @no_fold_fill_like_tensor
-// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: linalg.generic
func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
%result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 95c5b8c91edf5..fe0ea58282149 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2218,6 +2218,30 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
// -----
+func.func @fold_empty_tensor_with_collapse() -> tensor<12xf32> {
+  %0 = tensor.empty() : tensor<1x12xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x12xf32> into tensor<12xf32>
+  return %1 : tensor<12xf32>
+}
+
+// A reshape of a statically shaped tensor.empty folds into a tensor.empty of
+// the reshaped type; nothing else remains in the body.
+// CHECK-LABEL: func @fold_empty_tensor_with_collapse()
+//  CHECK-NEXT:   %[[EMPTY:.+]] = tensor.empty() : tensor<12xf32>
+//  CHECK-NEXT:   return %[[EMPTY]] : tensor<12xf32>
+
+// -----
+
+func.func @fold_empty_tensor_with_expand() -> tensor<1x12xf32> {
+  %0 = tensor.empty() : tensor<12xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 12] : tensor<12xf32> into tensor<1x12xf32>
+  return %1 : tensor<1x12xf32>
+}
+
+// CHECK-LABEL: func @fold_empty_tensor_with_expand()
+//  CHECK-NEXT:   %[[EMPTY:.+]] = tensor.empty() : tensor<1x12xf32>
+//  CHECK-NEXT:   return %[[EMPTY]] : tensor<1x12xf32>
+
+// -----
+
+// A dynamically shaped tensor.empty is not folded through the reshape.
+func.func @no_fold_dynamic_empty_tensor_with_collapse(%arg0 : index) -> tensor<?xf32> {
+  %0 = tensor.empty(%arg0) : tensor<?x12xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x12xf32> into tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @no_fold_dynamic_empty_tensor_with_collapse
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[ARG0]]) : tensor<?x12xf32>
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EMPTY]]
+//       CHECK:   return %[[COLLAPSE]] : tensor<?xf32>
+
+// -----
+
func.func private @some_use(%i : index, %j : index)
// CHECK-LABEL: func @empty_tensor_canonicalize
@@ -2523,8 +2547,8 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
-> tensor<10x1x10xf32> {
- %c1 = arith.constant 1 : index
- %c10 = arith.constant 10 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
%0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
: tensor<?x?xf32> into tensor<?x?x?xf32>
@@ -2549,7 +2573,7 @@ func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
// CHECK-LABEL: func.func @sink_expand_of_cast
// CHECK-DAG: %[[C10:.*]] = arith.constant 10
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK: return %[[RES]]
@@ -2567,7 +2591,7 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
// CHECK-LABEL: func.func @partial_sink_expand_of_cast
// CHECK: %[[CAST:.+]] = tensor.cast
// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
More information about the Mlir-commits
mailing list