[Mlir-commits] [mlir] [mlir][tensor] Preserve encoding in `CollapseShapeOp::inferCollapsedType` (PR #173720)

Longsheng Mou llvmlistbot at llvm.org
Sat Dec 27 04:20:13 PST 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/173720

This PR fixes `CollapseShapeOp::inferCollapsedType` so that the inferred result type preserves the same encoding as the source type.

>From f474a5ac32949cec2a32504af6a8de0bfd15e8e6 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 27 Dec 2025 20:15:59 +0800
Subject: [PATCH] [mlir][tensor] Preserve encoding in
 `CollapseShapeOp::inferCollapsedType`

This PR fixes `CollapseShapeOp::inferCollapsedType` so that the inferred result type preserves the same encoding as the source type.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  3 ++-
 mlir/test/Dialect/Linalg/collapse-dim.mlir | 14 +++++++-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..6005046263f13 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2017,7 +2017,8 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
     currentDim += dim;
   }
 
-  return RankedTensorType::get(newShape, type.getElementType());
+  return RankedTensorType::get(newShape, type.getElementType(),
+                               type.getEncoding());
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 



More information about the Mlir-commits mailing list