[Mlir-commits] [mlir] 670a68e - [mlir][tensor] Preserve encoding in `CollapseShapeOp::build` (#173720)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 30 19:30:20 PST 2025


Author: Longsheng Mou
Date: 2025-12-31T11:30:16+08:00
New Revision: 670a68efd19f9b2f1333447719887b7204582a54

URL: https://github.com/llvm/llvm-project/commit/670a68efd19f9b2f1333447719887b7204582a54
DIFF: https://github.com/llvm/llvm-project/commit/670a68efd19f9b2f1333447719887b7204582a54.diff

LOG: [mlir][tensor] Preserve encoding in `CollapseShapeOp::build` (#173720)

This PR updates `CollapseShapeOp::build` so that when the result type is
not explicitly provided, the inferred result type preserves the encoding
of the source tensor.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/collapse-dim.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d2b6007c628..8b10c00008865 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1253,7 +1253,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
     inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
     static RankedTensorType
     inferCollapsedType(RankedTensorType type,
-                       SmallVector<ReassociationIndices> reassociation);
+                       ArrayRef<ReassociationIndices> reassociation);
   }];
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c331f90f8a0d..72acd02d0d13d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1900,11 +1900,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
-        MemRefType expandShapeResultType = MemRefType::get(
-            originalResultType.getShape(), originalResultType.getElementType());
         result = memref::ExpandShapeOp::create(
-            rewriter, loc, expandShapeResultType, collapsedOpResult,
-            reassociation, resultShape);
+            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = tensor::ExpandShapeOp::create(
             rewriter, loc, originalResultType, collapsedOpResult, reassociation,

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bed30d29db047..a0c7e40c20a46 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1985,7 +1985,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 }
 
 RankedTensorType CollapseShapeOp::inferCollapsedType(
-    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
+    RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
   return inferCollapsedType(
       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                 type.getContext(), reassociation)));
@@ -2023,10 +2023,11 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
-  auto resultType = inferCollapsedType(
-      llvm::cast<RankedTensorType>(src.getType()),
-      getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
+  auto srcType = llvm::cast<RankedTensorType>(src.getType());
+  RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
+  auto resultType =
+      RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
+                            srcType.getEncoding());
   result.addAttribute(getReassociationAttrStrName(),
                       getReassociationIndicesAttribute(b, reassociation));
   build(b, result, resultType, src, attrs);

diff  --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index c8b03f8dd5151..2995588e9abca 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
-// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 


        


More information about the Mlir-commits mailing list