[Mlir-commits] [mlir] [mlir][linalg] Preserve encoding in `getCollapsedOpOperand` (PR #173720)

Longsheng Mou llvmlistbot at llvm.org
Sat Dec 27 22:42:52 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173720

>From eb14abcd09bebe6805a7e627b36f69830654a60d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 27 Dec 2025 20:15:59 +0800
Subject: [PATCH] [mlir][tensor] Preserve encoding in `CollapseShapeOp::build`

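This PR updates `CollapseShapeOp::build` so that, when the result type is not explicitly provided, the inferred result type preserves the encoding of the source tensor.

It also changes the `inferCollapsedType` overload that takes reassociation indices to accept an `ArrayRef<ReassociationIndices>` instead of a `SmallVector<ReassociationIndices>` by value, and makes the memref path of `collapseOpIterationDims` pass the original result type to `memref::ExpandShapeOp` directly instead of rebuilding a `MemRefType` from only its shape and element type.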
---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td   |  2 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp      |  6 ++----
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp           | 11 ++++++-----
 mlir/test/Dialect/Linalg/collapse-dim.mlir         | 14 +++++++-------
 4 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d2b6007c628..8b10c00008865 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1253,7 +1253,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
     inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
     static RankedTensorType
     inferCollapsedType(RankedTensorType type,
-                       SmallVector<ReassociationIndices> reassociation);
+                       ArrayRef<ReassociationIndices> reassociation);
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c331f90f8a0d..72acd02d0d13d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1900,11 +1900,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
-        MemRefType expandShapeResultType = MemRefType::get(
-            originalResultType.getShape(), originalResultType.getElementType());
         result = memref::ExpandShapeOp::create(
-            rewriter, loc, expandShapeResultType, collapsedOpResult,
-            reassociation, resultShape);
+            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = tensor::ExpandShapeOp::create(
             rewriter, loc, originalResultType, collapsedOpResult, reassociation,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..afd5414a190e5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1985,7 +1985,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 }
 
 RankedTensorType CollapseShapeOp::inferCollapsedType(
-    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
+    RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
   return inferCollapsedType(
       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                 type.getContext(), reassociation)));
@@ -2023,10 +2023,11 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
-  auto resultType = inferCollapsedType(
-      llvm::cast<RankedTensorType>(src.getType()),
-      getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
+  auto srcType = llvm::cast<RankedTensorType>(src.getType());
+  RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
+  auto resultType =
+      RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
+                            srcType.getEncoding());
   result.addAttribute(getReassociationAttrStrName(),
                       getReassociationIndicesAttribute(b, reassociation));
   build(b, result, resultType, src, attrs);
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 



More information about the Mlir-commits mailing list