[Mlir-commits] [mlir] [mlir][tensor] Preserve encoding in `CollapseShapeOp::inferCollapsedType` (PR #173720)
Longsheng Mou
llvmlistbot at llvm.org
Sat Dec 27 05:58:49 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173720
>From 913428a1c261e8a991fe7ac7e28d363a33797e07 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 27 Dec 2025 20:15:59 +0800
Subject: [PATCH] [mlir][linalg] Preserve encoding in `getCollapsedOpOperand`
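This PR fixes `getCollapsedOpOperand` so that the result type of the `tensor.collapse_shape` op it creates preserves the encoding of the source tensor type. It also makes `collapseOpIterationDims` pass the original result type directly to `memref.expand_shape` instead of rebuilding a `MemRefType` from only the shape and element type.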
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 14 +++++++++-----
mlir/test/Dialect/Linalg/collapse-dim.mlir | 14 +++++++-------
2 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c331f90f8a0d..501101abce8db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1682,7 +1682,13 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
operandReassociation)
.getResult();
}
- return tensor::CollapseShapeOp::create(builder, loc, operand,
+ RankedTensorType operandType = cast<RankedTensorType>(operand.getType());
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ operandType, operandReassociation);
+ auto resultType = RankedTensorType::get(collapsedType.getShape(),
+ operandType.getElementType(),
+ operandType.getEncoding());
+ return tensor::CollapseShapeOp::create(builder, loc, resultType, operand,
operandReassociation)
.getResult();
}
@@ -1900,11 +1906,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
applyPermutationMap(indexingMap, ArrayRef(loopBound));
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
- MemRefType expandShapeResultType = MemRefType::get(
- originalResultType.getShape(), originalResultType.getElementType());
result = memref::ExpandShapeOp::create(
- rewriter, loc, expandShapeResultType, collapsedOpResult,
- reassociation, resultShape);
+ rewriter, loc, originalResultType, collapsedOpResult, reassociation,
+ resultShape);
} else {
result = tensor::ExpandShapeOp::create(
rewriter, loc, originalResultType, collapsedOpResult, reassociation,
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
// CHECK-LABEL: func.func @linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: }
More information about the Mlir-commits mailing list