[Mlir-commits] [mlir] 670a68e - [mlir][tensor] Preserve encoding in `CollapseShapeOp::build` (#173720)
llvmlistbot at llvm.org
Tue Dec 30 19:30:20 PST 2025
Author: Longsheng Mou
Date: 2025-12-31T11:30:16+08:00
New Revision: 670a68efd19f9b2f1333447719887b7204582a54
URL: https://github.com/llvm/llvm-project/commit/670a68efd19f9b2f1333447719887b7204582a54
DIFF: https://github.com/llvm/llvm-project/commit/670a68efd19f9b2f1333447719887b7204582a54.diff
LOG: [mlir][tensor] Preserve encoding in `CollapseShapeOp::build` (#173720)
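This PR updates `CollapseShapeOp::build` so that, when the result type is
not explicitly provided, the inferred result type preserves the encoding
of the source tensor. It also changes the `inferCollapsedType` overload
taking reassociation indices to accept `ArrayRef<ReassociationIndices>`
instead of `SmallVector<ReassociationIndices>`. In addition,
`linalg::collapseOpIterationDims` now re-expands memref results directly
to the original result type instead of rebuilding a plain `MemRefType`
from its shape and element type.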
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/collapse-dim.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d2b6007c628..8b10c00008865 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1253,7 +1253,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
static RankedTensorType
inferCollapsedType(RankedTensorType type,
- SmallVector<ReassociationIndices> reassociation);
+ ArrayRef<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c331f90f8a0d..72acd02d0d13d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1900,11 +1900,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
applyPermutationMap(indexingMap, ArrayRef(loopBound));
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
- MemRefType expandShapeResultType = MemRefType::get(
- originalResultType.getShape(), originalResultType.getElementType());
result = memref::ExpandShapeOp::create(
- rewriter, loc, expandShapeResultType, collapsedOpResult,
- reassociation, resultShape);
+ rewriter, loc, originalResultType, collapsedOpResult, reassociation,
+ resultShape);
} else {
result = tensor::ExpandShapeOp::create(
rewriter, loc, originalResultType, collapsedOpResult, reassociation,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bed30d29db047..a0c7e40c20a46 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1985,7 +1985,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
- RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
+ RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
return inferCollapsedType(
type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
type.getContext(), reassociation)));
@@ -2023,10 +2023,11 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
- auto resultType = inferCollapsedType(
- llvm::cast<RankedTensorType>(src.getType()),
- getSymbolLessAffineMaps(
- convertReassociationIndicesToExprs(b.getContext(), reassociation)));
+ auto srcType = llvm::cast<RankedTensorType>(src.getType());
+ RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
+ auto resultType =
+ RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
+ srcType.getEncoding());
result.addAttribute(getReassociationAttrStrName(),
getReassociationIndicesAttribute(b, reassociation));
build(b, result, resultType, src, attrs);
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
// CHECK-LABEL: func.func @linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: }
More information about the Mlir-commits mailing list