[Mlir-commits] [mlir] [mlir][tensor] Preserve encoding in `CollapseShapeOp::inferCollapsedType` (PR #173720)

Longsheng Mou llvmlistbot at llvm.org
Sat Dec 27 05:58:49 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173720

>From 913428a1c261e8a991fe7ac7e28d363a33797e07 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 27 Dec 2025 20:15:59 +0800
Subject: [PATCH] [mlir][linalg] Preserve encoding in `getCollapsedOpOperand`

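This PR fixes `getCollapsedOpOperand` so that the result type of the `tensor.collapse_shape` it creates keeps the encoding of the source tensor type. Previously, the inferred collapsed type dropped that encoding. It also makes `collapseOpIterationDims` create `memref.expand_shape` with the original result type. Previously it built a fresh `MemRefType` from only the shape and element type, which discarded any layout or memory space on the original memref.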
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp      | 14 +++++++++-----
 mlir/test/Dialect/Linalg/collapse-dim.mlir         | 14 +++++++-------
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c331f90f8a0d..501101abce8db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1682,7 +1682,13 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                            operandReassociation)
         .getResult();
   }
-  return tensor::CollapseShapeOp::create(builder, loc, operand,
+  RankedTensorType operandType = cast<RankedTensorType>(operand.getType());
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      operandType, operandReassociation);
+  auto resultType = RankedTensorType::get(collapsedType.getShape(),
+                                          operandType.getElementType(),
+                                          operandType.getEncoding());
+  return tensor::CollapseShapeOp::create(builder, loc, resultType, operand,
                                          operandReassociation)
       .getResult();
 }
@@ -1900,11 +1906,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
-        MemRefType expandShapeResultType = MemRefType::get(
-            originalResultType.getShape(), originalResultType.getElementType());
         result = memref::ExpandShapeOp::create(
-            rewriter, loc, expandShapeResultType, collapsedOpResult,
-            reassociation, resultShape);
+            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = tensor::ExpandShapeOp::create(
             rewriter, loc, originalResultType, collapsedOpResult, reassociation,
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 



More information about the Mlir-commits mailing list