[Mlir-commits] [mlir] [mlir][Linalg] Allow more control in drop unit dims (PR #170104)

Lukas Sommer llvmlistbot at llvm.org
Mon Dec 1 06:21:19 PST 2025


https://github.com/sommerlukas updated https://github.com/llvm/llvm-project/pull/170104

>From 576eab032c13c6afd0afc43cbb4afa9a30b8acec Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Thu, 27 Nov 2025 14:18:45 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Allow more control in drop unit dims

Extend the `ControlDropUnitDims` struct to give users of the `linalg::dropUnitDims` function more control over its behavior.

The extended struct allows users to specify functions to:
- Compute the new shape and indexing map of each operand, which also controls which operands get their unit-extent dimensions dropped.
- Control how the operands are collapsed.
- Control how the results are expanded back to the original shape.

One example where this additional control is useful (and the motivation for this change) is allowing tensors with an encoding to be collapsed while preserving that encoding, as demonstrated by the new test.

The default implementations preserve the previous behavior, so existing users of the interface do not need to make any changes to their code.

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 85 ++++++++++++++++++
 .../Linalg/Transforms/DropUnitDims.cpp        | 86 ++++++-------------
 .../Dialect/Linalg/test-drop-unit-dims.mlir   | 42 +++++++++
 .../Dialect/Linalg/TestLinalgDropUnitDims.cpp | 67 ++++++++++++++-
 4 files changed, 217 insertions(+), 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..d13e5ae935e00 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -524,6 +524,14 @@ struct ControlDropUnitDims {
   RankReductionStrategy rankReductionStrategy =
       RankReductionStrategy::ReassociativeReshape;
 
+  struct UnitExtentReplacementInfo {
+    AffineMap indexMap;
+    SmallVector<ReassociationIndices> reassociation;
+    SmallVector<int64_t> targetShape;
+  };
+
+  using DimensionMapping = llvm::SmallDenseMap<unsigned, unsigned>;
+
   using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
   ControlFnTy controlFn = [](Operation *op) {
     if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
@@ -535,6 +543,83 @@ struct ControlDropUnitDims {
     }
     return SmallVector<unsigned>{};
   };
+
+  using ComputeOperandShapeAndMapFnTy = std::function<UnitExtentReplacementInfo(
+      const ControlDropUnitDims &, MLIRContext *, IndexingMapOpInterface,
+      OpOperand *, DimensionMapping &, ArrayRef<AffineExpr>)>;
+  ComputeOperandShapeAndMapFnTy computeOperandShapeAndMapFn =
+      [](const ControlDropUnitDims &control, MLIRContext *context,
+         IndexingMapOpInterface op, OpOperand *opOperand,
+         DimensionMapping &oldDimsToNewDimsMap,
+         ArrayRef<AffineExpr> dimReplacements) -> UnitExtentReplacementInfo {
+    auto hasCollapsibleType = [](OpOperand &operand) {
+      Type operandType = operand.get().getType();
+      if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
+        return memrefOperandType.getLayout().isIdentity();
+      }
+      if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
+        return tensorOperandType.getEncoding() == nullptr;
+      }
+      return false;
+    };
+    auto indexingMap = op.getMatchingIndexingMap(opOperand);
+    SmallVector<int64_t> shape = op.getStaticOperandShape(opOperand);
+    if (!hasCollapsibleType(*opOperand)) {
+      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+          dimReplacements, ArrayRef<AffineExpr>{}, oldDimsToNewDimsMap.size(),
+          0);
+      UnitExtentReplacementInfo info;
+      info.indexMap = newIndexingMap;
+      info.targetShape = llvm::to_vector(shape);
+      return info;
+    }
+    return control.dropUnitExtentFromOperandMetadata(
+        context, op, opOperand, oldDimsToNewDimsMap, dimReplacements);
+  };
+
+  using CollapseValueFnTy = std::function<Value(
+      const ControlDropUnitDims &, RewriterBase &, Location, Value,
+      ArrayRef<int64_t>, ArrayRef<ReassociationIndices>)>;
+  CollapseValueFnTy collapseValueFn =
+      [](const ControlDropUnitDims &control, RewriterBase &rewriter,
+         Location loc, Value operand, ArrayRef<int64_t> targetShape,
+         ArrayRef<ReassociationIndices> reassociation) -> Value {
+    return control.collapseValue(rewriter, loc, operand, targetShape,
+                                 reassociation);
+  };
+
+  using ExpandValueFnTy =
+      std::function<Value(const ControlDropUnitDims &, RewriterBase &, Location,
+                          Value, Value, ArrayRef<ReassociationIndices>)>;
+  ExpandValueFnTy expandValueFn =
+      [](const ControlDropUnitDims &control, RewriterBase &rewriter,
+         Location loc, Value result, Value origDest,
+         ArrayRef<ReassociationIndices> reassociation) -> Value {
+    return control.expandValue(rewriter, loc, result, origDest, reassociation);
+  };
+
+  /// Compute the modified metadata for an operand of an operation
+  /// whose unit dims are being dropped. Return the new indexing map
+  /// to use, the shape of the operand in the replacement op
+  /// and the `reassociation` to use to go from original operand shape
+  /// to modified operand shape.
+  UnitExtentReplacementInfo
+  dropUnitExtentFromOperandMetadata(MLIRContext *, IndexingMapOpInterface,
+                                    OpOperand *, DimensionMapping &,
+                                    ArrayRef<AffineExpr>) const;
+
+  /// Collapse the given `operand` so that its shape matches `targetShape`.
+  /// The `reassociation` is used when `rankReductionStrategy` is
+  /// set to `RankReductionStrategy::ReassociativeReshape`.
+  Value collapseValue(RewriterBase &, Location, Value, ArrayRef<int64_t>,
+                      ArrayRef<ReassociationIndices>) const;
+
+  /// Expand the given `value` so that the type matches the type of `origDest`.
+  /// The `reassociation` is used when `rankReductionStrategy` is set to
+  /// `RankReductionStrategy::ReassociativeReshape`.
+  Value expandValue(RewriterBase &rewriter, Location loc, Value result,
+                    Value origDest,
+                    ArrayRef<ReassociationIndices> reassociation) const;
 };
 
 struct DropUnitDimsResult {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9e6c1e6036cba..f3bfdc255a5ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -244,13 +244,9 @@ replaceUnitDimIndexOps(GenericOp genericOp,
   }
 }
 
-/// Expand the given `value` so that the type matches the type of `origDest`.
-/// The `reassociation` is used when `rankReductionStrategy` is set to
-/// `RankReductionStrategy::ReassociativeReshape`.
-static Value
-expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
-            ArrayRef<ReassociationIndices> reassociation,
-            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+Value ControlDropUnitDims::expandValue(
+    RewriterBase &rewriter, Location loc, Value result, Value origDest,
+    ArrayRef<ReassociationIndices> reassociation) const {
   // There are no results for memref outputs.
   auto origResultType = cast<RankedTensorType>(origDest.getType());
   if (rankReductionStrategy ==
@@ -272,13 +268,10 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
       .getResult();
 }
 
-/// Collapse the given `value` so that the type matches the type of
-/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
-/// set to `RankReductionStrategy::ReassociativeReshape`.
-static Value collapseValue(
+Value ControlDropUnitDims::collapseValue(
     RewriterBase &rewriter, Location loc, Value operand,
-    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
-    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+    ArrayRef<int64_t> targetShape,
+    ArrayRef<ReassociationIndices> reassociation) const {
   if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
     if (rankReductionStrategy ==
         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
@@ -321,20 +314,11 @@ static Value collapseValue(
   llvm_unreachable("unsupported operand type");
 }
 
-/// Compute the modified metadata for an operands of operation
-/// whose unit dims are being dropped. Return the new indexing map
-/// to use, the shape of the operand in the replacement op
-/// and the `reassocation` to use to go from original operand shape
-/// to modified operand shape.
-struct UnitExtentReplacementInfo {
-  AffineMap indexMap;
-  SmallVector<ReassociationIndices> reassociation;
-  SmallVector<int64_t> targetShape;
-};
-static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
+ControlDropUnitDims::UnitExtentReplacementInfo
+ControlDropUnitDims::dropUnitExtentFromOperandMetadata(
     MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
     llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
-    ArrayRef<AffineExpr> dimReplacements) {
+    ArrayRef<AffineExpr> dimReplacements) const {
   UnitExtentReplacementInfo info;
   ReassociationIndices reassociationGroup;
   SmallVector<AffineExpr> newIndexExprs;
@@ -457,31 +441,11 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
   SmallVector<SmallVector<ReassociationIndices>> reassociations;
   SmallVector<SmallVector<int64_t>> targetShapes;
   SmallVector<bool> collapsed;
-  auto hasCollapsibleType = [](OpOperand &operand) {
-    Type operandType = operand.get().getType();
-    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
-      return memrefOperandType.getLayout().isIdentity();
-    }
-    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
-      return tensorOperandType.getEncoding() == nullptr;
-    }
-    return false;
-  };
   for (OpOperand &opOperand : op->getOpOperands()) {
     auto indexingMap = op.getMatchingIndexingMap(&opOperand);
-    SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
-    if (!hasCollapsibleType(opOperand)) {
-      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
-          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
-      newIndexingMaps.push_back(newIndexingMap);
-      targetShapes.push_back(llvm::to_vector(shape));
-      collapsed.push_back(false);
-      reassociations.push_back({});
-      continue;
-    }
-    auto replacementInfo =
-        dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
-                                          oldDimToNewDimMap, dimReplacements);
+    auto replacementInfo = options.computeOperandShapeAndMapFn(
+        options, rewriter.getContext(), op, &opOperand, oldDimToNewDimMap,
+        dimReplacements);
     reassociations.push_back(replacementInfo.reassociation);
     newIndexingMaps.push_back(replacementInfo.indexMap);
     targetShapes.push_back(replacementInfo.targetShape);
@@ -508,9 +472,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
       newOperands.push_back(opOperand.get());
       continue;
     }
-    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
-                                        targetShapes[idx], reassociations[idx],
-                                        options.rankReductionStrategy));
+    newOperands.push_back(
+        options.collapseValueFn(options, rewriter, loc, opOperand.get(),
+                                targetShapes[idx], reassociations[idx]));
   }
 
   IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
@@ -526,9 +490,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
       resultReplacements.push_back(result);
       continue;
     }
-    Value expandedValue = expandValue(rewriter, loc, result, origDest,
-                                      reassociations[opOperandIndex],
-                                      options.rankReductionStrategy);
+    Value expandedValue =
+        options.expandValueFn(options, rewriter, loc, result, origDest,
+                              reassociations[opOperandIndex]);
     resultReplacements.push_back(expandedValue);
   }
 
@@ -686,8 +650,8 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
     }
 
     Value collapsedSource =
-        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
-                      reassociationMap, options.rankReductionStrategy);
+        options.collapseValueFn(options, rewriter, padOp.getLoc(),
+                                padOp.getSource(), newShape, reassociationMap);
 
     auto newResultType = RankedTensorType::get(
         newResultShape, padOp.getResultType().getElementType());
@@ -714,8 +678,8 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
     }
 
     Value expandedValue =
-        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
-                    reassociationMap, options.rankReductionStrategy);
+        options.expandValueFn(options, rewriter, padOp.getLoc(),
+                              newPadOp.getResult(), dest, reassociationMap);
     rewriter.replaceOp(padOp, expandedValue);
     return success();
   }
@@ -904,10 +868,10 @@ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
   auto valType = cast<ShapedType>(val.getType());
   SmallVector<int64_t> collapsedShape(valType.getShape());
   collapsedShape.erase(collapsedShape.begin() + pos);
-  return collapseValue(
+  ControlDropUnitDims options;
+  return options.collapseValue(
       rewriter, val.getLoc(), val, collapsedShape,
-      getReassociationForReshapeAtDim(valType.getRank(), pos),
-      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+      getReassociationForReshapeAtDim(valType.getRank(), pos));
 }
 
 /// Base class for all rank reduction patterns for contraction ops
diff --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
index 35eeffc1f9953..40e68ad8cbe31 100644
--- a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
+++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-linalg-drop-unit-dims="preserve-encoding" --split-input-file %s | FileCheck %s --check-prefix=PRESERVE
 
 // Drop only the outermost unit dimension (controlled using a control function)
 func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> {
@@ -24,3 +25,44 @@ func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42x
 //  CHECK-SAME:       outs(%[[OUTS_RESHAPE]] :
 //       CHECK:   %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
 //       CHECK:   return %[[EXPAND_SHAPE]]
+
+// -----
+
+#encoding = #test.tensor_encoding<"encoding">
+
+// Test that tensor encodings are preserved when collapsing unit dimensions
+func.func @drop_outermost_unit_dims_with_encoding(%arg0: tensor<1x1x42xf32, #encoding>) -> tensor<1x1x42xf32, #encoding> {
+  %0 = tensor.empty() : tensor<1x1x42xf32, #encoding>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<1x1x42xf32, #encoding>)
+    outs(%0 : tensor<1x1x42xf32, #encoding>) {
+      ^bb0(%b0: f32, %b1 : f32):
+        %2 = arith.addf %b0, %b1 : f32
+        linalg.yield %2 : f32
+    } -> tensor<1x1x42xf32, #encoding>
+  return %1 : tensor<1x1x42xf32, #encoding>
+}
+// Without preserve-encoding flag, encoded tensors are not collapsed
+// CHECK-LABEL: func @drop_outermost_unit_dims_with_encoding
+//       CHECK:   linalg.generic
+//   CHECK-NOT:   tensor.collapse_shape
+//   CHECK-NOT:   tensor.expand_shape
+
+// With preserve-encoding flag, encodings are preserved through collapse/expand
+// PRESERVE: affine_map<(d0, d1) -> (d0, d1)>
+// PRESERVE-LABEL: func @drop_outermost_unit_dims_with_encoding
+//  PRESERVE-SAME:     %[[ARG0:.+]]: tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+//       PRESERVE:   %[[OUTS:.+]] = tensor.empty() : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+//       PRESERVE:   %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}}
+//  PRESERVE-SAME:       : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x42xf32, #test.tensor_encoding<"encoding">>
+//       PRESERVE:   %[[OUTS_RESHAPE:.+]] = tensor.collapse_shape %[[OUTS]] {{\[}}[0, 1], [2]{{\]}}
+//  PRESERVE-SAME:       : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x42xf32, #test.tensor_encoding<"encoding">>
+//       PRESERVE:   %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+//  PRESERVE-SAME:       ins(%[[ARG0_RESHAPE]] : tensor<1x42xf32, #test.tensor_encoding<"encoding">>)
+//  PRESERVE-SAME:       outs(%[[OUTS_RESHAPE]] : tensor<1x42xf32, #test.tensor_encoding<"encoding">>)
+//       PRESERVE:   %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
+//  PRESERVE-SAME:       : tensor<1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+//       PRESERVE:   return %[[EXPAND_SHAPE]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 402ce154c0848..20d0d61d0211f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -34,13 +34,72 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
   return success();
 }
 
+LogicalResult dropOutermostUnitDimsWithEncoding(RewriterBase &rewriter,
+                                                linalg::GenericOp genericOp) {
+  linalg::ControlDropUnitDims options;
+
+  options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
+  options.computeOperandShapeAndMapFn =
+      [](const linalg::ControlDropUnitDims &control, MLIRContext *context,
+         IndexingMapOpInterface op, OpOperand *opOperand,
+         linalg::ControlDropUnitDims::DimensionMapping &oldDimsToNewDimsMap,
+         ArrayRef<AffineExpr> dimReplacements)
+      -> linalg::ControlDropUnitDims::UnitExtentReplacementInfo {
+    auto isCollapsible = [](Type ty) { return isa<RankedTensorType>(ty); };
+    auto indexingMap = op.getMatchingIndexingMap(opOperand);
+    SmallVector<int64_t> shape = op.getStaticOperandShape(opOperand);
+    if (!isCollapsible(opOperand->get().getType())) {
+      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+          dimReplacements, ArrayRef<AffineExpr>{}, oldDimsToNewDimsMap.size(),
+          0);
+      linalg::ControlDropUnitDims::UnitExtentReplacementInfo info;
+      info.indexMap = newIndexingMap;
+      info.targetShape = llvm::to_vector(shape);
+      return info;
+    }
+    return control.dropUnitExtentFromOperandMetadata(
+        context, op, opOperand, oldDimsToNewDimsMap, dimReplacements);
+  };
+
+  // Preserve encoding when collapsing
+  options.collapseValueFn =
+      [](const linalg::ControlDropUnitDims &control, RewriterBase &rewriter,
+         Location loc, Value operand, ArrayRef<int64_t> targetShape,
+         ArrayRef<ReassociationIndices> reassociation) -> Value {
+    auto tensorType = cast<RankedTensorType>(operand.getType());
+    assert(control.rankReductionStrategy ==
+               linalg::ControlDropUnitDims::RankReductionStrategy::
+                   ReassociativeReshape &&
+           "unexpected rank reduction strategy");
+    auto targetType = RankedTensorType::get(
+        targetShape, tensorType.getElementType(), tensorType.getEncoding());
+    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
+                                           reassociation);
+  };
+
+  FailureOr<linalg::DropUnitDimsResult> result =
+      linalg::dropUnitDims(rewriter, genericOp, options);
+  if (failed(result)) {
+    return failure();
+  }
+  rewriter.replaceOp(genericOp, result->replacements);
+  return success();
+}
+
 struct TestLinalgDropUnitDims
     : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
 
   TestLinalgDropUnitDims() = default;
-  TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
+  TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass)
+      : PassWrapper(pass) {}
+
+  Option<bool> preserveEncoding{
+      *this, "preserve-encoding",
+      llvm::cl::desc(
+          "Preserve tensor encodings when collapsing unit dimensions"),
+      llvm::cl::init(false)};
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect>();
@@ -63,7 +122,11 @@ struct TestLinalgDropUnitDims
 
     for (auto genericOp : genericOps) {
       rewriter.setInsertionPoint(genericOp);
-      (void)dropOutermostUnitDims(rewriter, genericOp);
+      if (preserveEncoding) {
+        (void)dropOutermostUnitDimsWithEncoding(rewriter, genericOp);
+      } else {
+        (void)dropOutermostUnitDims(rewriter, genericOp);
+      }
     }
   }
 };

>From e3a2848e47525ca1b9564dbda6fe3ad44d2bc521 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 1 Dec 2025 14:20:53 +0000
Subject: [PATCH 2/2] Address PR feedback

Flip logic to enable simpler early exit.

Rename option in test pass to clarify behavior.

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 19 +++++++++----------
 .../Dialect/Linalg/test-drop-unit-dims.mlir   |  2 +-
 .../Dialect/Linalg/TestLinalgDropUnitDims.cpp |  4 ++--
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d13e5ae935e00..18aa863775728 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -564,17 +564,16 @@ struct ControlDropUnitDims {
     };
     auto indexingMap = op.getMatchingIndexingMap(opOperand);
     SmallVector<int64_t> shape = op.getStaticOperandShape(opOperand);
-    if (!hasCollapsibleType(*opOperand)) {
-      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
-          dimReplacements, ArrayRef<AffineExpr>{}, oldDimsToNewDimsMap.size(),
-          0);
-      UnitExtentReplacementInfo info;
-      info.indexMap = newIndexingMap;
-      info.targetShape = llvm::to_vector(shape);
-      return info;
+    if (hasCollapsibleType(*opOperand)) {
+      return control.dropUnitExtentFromOperandMetadata(
+          context, op, opOperand, oldDimsToNewDimsMap, dimReplacements);
     }
-    return control.dropUnitExtentFromOperandMetadata(
-        context, op, opOperand, oldDimsToNewDimsMap, dimReplacements);
+    AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+        dimReplacements, ArrayRef<AffineExpr>{}, oldDimsToNewDimsMap.size(), 0);
+    UnitExtentReplacementInfo info;
+    info.indexMap = newIndexingMap;
+    info.targetShape = llvm::to_vector(shape);
+    return info;
   };
 
   using CollapseValueFnTy = std::function<Value(
diff --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
index 40e68ad8cbe31..5a8c2ce3c84ed 100644
--- a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
+++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s
-// RUN: mlir-opt -test-linalg-drop-unit-dims="preserve-encoding" --split-input-file %s | FileCheck %s --check-prefix=PRESERVE
+// RUN: mlir-opt -test-linalg-drop-unit-dims="collapse-encoded" --split-input-file %s | FileCheck %s --check-prefix=PRESERVE
 
 // Drop only the outermost unit dimension (controlled using a control function)
 func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> {
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 20d0d61d0211f..6089360135530 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -96,9 +96,9 @@ struct TestLinalgDropUnitDims
       : PassWrapper(pass) {}
 
   Option<bool> preserveEncoding{
-      *this, "preserve-encoding",
+      *this, "collapse-encoded",
       llvm::cl::desc(
-          "Preserve tensor encodings when collapsing unit dimensions"),
+          "Collapse tensors with encodings and unit extend dimensions"),
       llvm::cl::init(false)};
 
   void getDependentDialects(DialectRegistry &registry) const override {



More information about the Mlir-commits mailing list