[Mlir-commits] [mlir] [mlir][Linalg] Allow more control in drop unit dims (PR #171796)
Lukas Sommer
llvmlistbot at llvm.org
Thu Dec 11 02:31:09 PST 2025
https://github.com/sommerlukas created https://github.com/llvm/llvm-project/pull/171796
Extend the ControlDropUnitDims struct to allow users of the `linalg::dropUnitDims` function more control over the behavior of the function.
The extended struct allows users to specify functions to control how the operands are collapsed and how the result is expanded to the original shape.
One example (and the motivation for this change) where this additional control is useful is to allow collapsing of tensors with an encoding, as demonstrated by the new test.
This is a breaking change. The new default behavior is to abort the transformation if one of the operands cannot be collapsed or if a result cannot be expanded. This is the case for `memref`s with a non-identity layout and `tensor`s with an encoding.
>From 016e6220da42aeb86ee6218ef63ba6480d1fc673 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Tue, 9 Dec 2025 14:15:27 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Allow more control in drop unit dims
Extend the `ControlDropUnitDims` struct to allow users of the `linalg::dropUnitDims` function more control over the behavior of the function.
The extended struct allows users to specify functions to control how the operands are collapsed and how the result is expanded to the original shape.
One example (and the motivation for this change) where this additional control is useful is to allow collapsing of tensors with an encoding, as demonstrated by the new test.
This is a breaking change. The default behavior is now to abort the transformation if one of the operands cannot be collapsed or if a result cannot be expanded. This is the case for memrefs with a non-identity layout and tensors with an encoding.
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 97 +++++++++++++
.../Linalg/Transforms/DropUnitDims.cpp | 134 ++++++++++--------
.../Dialect/Linalg/drop-unit-extent-dims.mlir | 38 +++--
.../Dialect/Linalg/test-drop-unit-dims.mlir | 99 ++++++++++++-
.../Dialect/Linalg/TestLinalgDropUnitDims.cpp | 65 ++++++++-
5 files changed, 350 insertions(+), 83 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..e47ed3f0873ad 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -516,6 +516,28 @@ LogicalResult vectorizeOpPrecondition(Operation *op,
using LinalgLoops = SmallVector<Operation *, 4>;
+// Forward declaration, needed by the helper declarations below.
+struct ControlDropUnitDims;
+
+/// Collapse the given \p operand (a memref or ranked tensor) to
+/// \p targetShape. The \p reassociation is used when `rankReductionStrategy`
+/// of \p control is set to `RankReductionStrategy::ReassociativeReshape`.
+/// Will return failure if the operand has memref type with a non-identity
+/// layout or tensor type with an encoding.
+FailureOr<Value> collapseValue(RewriterBase &rewriter, Location loc,
+ Value operand, ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control);
+
+/// Expand the given \p result so that its type matches the ranked tensor type
+/// of \p origDest. The \p reassociation is used when `rankReductionStrategy`
+/// of \p control is set to `RankReductionStrategy::ReassociativeReshape`.
+/// Will return failure if \p origDest has a tensor type with an encoding.
+FailureOr<Value> expandValue(RewriterBase &rewriter, Location loc, Value result,
+ Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control);
+
/// Transformation to drop unit-extent dimensions from `linalg.generic`
/// operations.
struct ControlDropUnitDims {
@@ -524,7 +546,19 @@ struct ControlDropUnitDims {
RankReductionStrategy rankReductionStrategy =
RankReductionStrategy::ReassociativeReshape;
+ /// Instances of this type are used to control which dimensions of an
+ /// operation (for a linalg.generic: its loops) are considered for dropping
+ /// unit extent dimensions. The parameter to the function is the operation
+ /// itself; the expected return is the list of dimensions to consider. If the
+ /// operation should not have any dimensions dropped, implementations should
+ /// return an empty list.
using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
+
+ /// Function to control which dimensions, if any, are to be considered for
+ /// dropping unit extent dimensions. The default behavior is to consider all
+ /// dimensions of a \c linalg.generic or \c tensor.pad operation for dropping.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
ControlFnTy controlFn = [](Operation *op) {
if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
@@ -535,6 +569,58 @@ struct ControlDropUnitDims {
}
return SmallVector<unsigned>{};
};
+
+ /// Instances of this type are used to control how operand values are
+ /// collapsed after dropping unit extent dimensions. Besides the control
+ /// struct, rewriter and location, the function receives the operand value to
+ /// collapse, the new target shape and how old dimensions should be grouped.
+ /// The function needs to insert the necessary operations to collapse the
+ /// operand to the target shape and returns the new operand value.
+ /// If the operand should not be collapsed, the function should return
+ /// failure, which causes the transformation to be aborted.
+ using CollapseFnTy = std::function<FailureOr<Value>(
+ RewriterBase &, Location, Value, ArrayRef<int64_t>,
+ ArrayRef<ReassociationIndices>, const ControlDropUnitDims &)>;
+
+ /// Function to control how operands are collapsed into their new target shape
+ /// after dropping unit extent dimensions. For the default behavior
+ /// \see linalg::collapseValue.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
+ CollapseFnTy collapseFn =
+ [](RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) -> FailureOr<Value> {
+ return linalg::collapseValue(rewriter, loc, operand, targetShape,
+ reassociation, control);
+ };
+
+ /// Instances of this type are used to control how result values are expanded
+ /// into their original shape after dropping unit extent dimensions. Besides
+ /// the control struct, rewriter and location, the function receives the
+ /// result value, the original destination whose type the result must match,
+ /// and information on how the new dimensions were grouped.
+ /// The function needs to insert the necessary operations to expand the
+ /// result to the original shape and returns the new result value.
+ /// If the result should not be expanded, the function should return
+ /// failure, which causes the transformation to be aborted.
+ using ExpandFnTy = std::function<FailureOr<Value>(
+ RewriterBase &, Location, Value, Value, ArrayRef<ReassociationIndices>,
+ const ControlDropUnitDims &)>;
+
+ /// Function to control how results are expanded into their original shape
+ /// after dropping unit extent dimensions. For the default behavior
+ /// \see linalg::expandValue.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
+ ExpandFnTy expandFn =
+ [](RewriterBase &rewriter, Location loc, Value result, Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) -> FailureOr<Value> {
+ return linalg::expandValue(rewriter, loc, result, origDest, reassociation,
+ control);
+ };
};
struct DropUnitDimsResult {
@@ -546,10 +632,21 @@ using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
const llvm::SmallDenseSet<unsigned> &droppedDims)>;
+/// Drop unit extent dimensions from \p op and its operands, building the
+/// replacement with \p droppedUnitDimsBuilder. Operands/results are reshaped
+/// via the `collapseFn`/`expandFn` callbacks of \p options; if any callback
+/// fails, the transformation is aborted. It may then leave trivially dead
+/// operations behind and should therefore not be called from greedy drivers.
FailureOr<DropUnitDimsResult>
dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options);
+
+/// Drop unit extent dimensions from \p genericOp and its operands.
+/// The transformation is aborted if unit dimensions cannot be dropped from any
+/// of the operands. Note that this function may insert trivially dead
+/// operations if the transformation is aborted and should therefore not be
+/// called from greedy drivers.
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9e6c1e6036cba..0fb7ca08b0ace 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -244,16 +244,18 @@ replaceUnitDimIndexOps(GenericOp genericOp,
}
}
-/// Expand the given `value` so that the type matches the type of `origDest`.
-/// The `reassociation` is used when `rankReductionStrategy` is set to
-/// `RankReductionStrategy::ReassociativeReshape`.
-static Value
-expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
- ArrayRef<ReassociationIndices> reassociation,
- ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+FailureOr<Value>
+linalg::expandValue(RewriterBase &rewriter, Location loc, Value result,
+ Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) {
// There are no results for memref outputs.
auto origResultType = cast<RankedTensorType>(origDest.getType());
- if (rankReductionStrategy ==
+ if (origResultType.getEncoding() != nullptr) {
+ // Do not expand tensors with encoding.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
unsigned rank = origResultType.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
@@ -264,7 +267,7 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
loc, result, origDest, offsets, sizes, strides);
}
- assert(rankReductionStrategy ==
+ assert(control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
@@ -272,15 +275,17 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
.getResult();
}
-/// Collapse the given `value` so that the type matches the type of
-/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
-/// set to `RankReductionStrategy::ReassociativeReshape`.
-static Value collapseValue(
- RewriterBase &rewriter, Location loc, Value operand,
- ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
- ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+FailureOr<Value>
+linalg::collapseValue(RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) {
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- if (rankReductionStrategy ==
+ if (!memrefType.getLayout().isIdentity()) {
+ // Memrefs with a non-identity layout are not collapsed; abort instead.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -290,17 +295,22 @@ static Value collapseValue(
}
assert(
- rankReductionStrategy ==
+ control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
MemRefLayoutAttrInterface layout;
auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
layout, memrefType.getMemorySpace());
return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
- reassociation);
+ reassociation)
+ .getResult();
}
if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
- if (rankReductionStrategy ==
+ if (tensorType.getEncoding() != nullptr) {
+ // Tensors with an encoding are not collapsed; abort instead.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -310,13 +320,14 @@ static Value collapseValue(
}
assert(
- rankReductionStrategy ==
+ control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
auto targetType =
RankedTensorType::get(targetShape, tensorType.getElementType());
return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
- reassociation);
+ reassociation)
+ .getResult();
}
llvm_unreachable("unsupported operand type");
}
@@ -457,28 +468,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
SmallVector<SmallVector<ReassociationIndices>> reassociations;
SmallVector<SmallVector<int64_t>> targetShapes;
SmallVector<bool> collapsed;
- auto hasCollapsibleType = [](OpOperand &operand) {
- Type operandType = operand.get().getType();
- if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
- return memrefOperandType.getLayout().isIdentity();
- }
- if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
- return tensorOperandType.getEncoding() == nullptr;
- }
- return false;
- };
for (OpOperand &opOperand : op->getOpOperands()) {
auto indexingMap = op.getMatchingIndexingMap(&opOperand);
- SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
- if (!hasCollapsibleType(opOperand)) {
- AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
- dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
- newIndexingMaps.push_back(newIndexingMap);
- targetShapes.push_back(llvm::to_vector(shape));
- collapsed.push_back(false);
- reassociations.push_back({});
- continue;
- }
auto replacementInfo =
dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
oldDimToNewDimMap, dimReplacements);
@@ -501,6 +492,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
// from original shape to shape in the modified operation if needed,
// either through use of reshapes or rank-reducing slices as
// specified in `options`.
+ // Abort if one of the operands cannot be collapsed.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
@@ -508,9 +500,14 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
newOperands.push_back(opOperand.get());
continue;
}
- newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
- targetShapes[idx], reassociations[idx],
- options.rankReductionStrategy));
+ FailureOr<Value> collapsedOperand =
+ options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
+ reassociations[idx], options);
+ if (failed(collapsedOperand)) {
+ // Abort; ops already created for earlier operands are left behind.
+ return failure();
+ }
+ newOperands.push_back(collapsedOperand.value());
}
IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
@@ -518,6 +515,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
// 6. If any result type changes, insert a reshape/slice to convert from the
// original type to the new type.
+ // Abort the transformation if the result cannot be expanded back to its
+ // original shape.
SmallVector<Value> resultReplacements;
for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
@@ -526,10 +525,14 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
resultReplacements.push_back(result);
continue;
}
- Value expandedValue = expandValue(rewriter, loc, result, origDest,
- reassociations[opOperandIndex],
- options.rankReductionStrategy);
- resultReplacements.push_back(expandedValue);
+ FailureOr<Value> expanded =
+ options.expandFn(rewriter, loc, result, origDest,
+ reassociations[opOperandIndex], options);
+ if (failed(expanded)) {
+ // Abort if expansion is not successful.
+ return failure();
+ }
+ resultReplacements.push_back(expanded.value());
}
return DropUnitDimsResult{replacementOp, resultReplacements};
@@ -685,15 +688,23 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
reassociationGroup.clear();
}
- Value collapsedSource =
- collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
- reassociationMap, options.rankReductionStrategy);
+ FailureOr<Value> collapsedSource =
+ options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
+ newShape, reassociationMap, options);
+ if (failed(collapsedSource)) {
+ return rewriter.notifyMatchFailure(padOp, "Failed to collapse source");
+ }
- auto newResultType = RankedTensorType::get(
- newResultShape, padOp.getResultType().getElementType());
+ // Preserve the encoding of the original result type (if any) so that custom
+ // collapse/expand callbacks can round-trip encoded tensors.
+ auto newResultType =
+ RankedTensorType::get(newResultShape,
+ padOp.getResultType().getElementType(),
+ padOp.getResultType().getEncoding());
auto newPadOp = tensor::PadOp::create(
- rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
- newLowPad, newHighPad, paddingVal, padOp.getNofold());
+ rewriter, padOp.getLoc(), /*result=*/newResultType,
+ collapsedSource.value(), newLowPad, newHighPad, paddingVal,
+ padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -713,10 +724,15 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
padOp.getResultType().getElementType());
}
- Value expandedValue =
- expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
- reassociationMap, options.rankReductionStrategy);
- rewriter.replaceOp(padOp, expandedValue);
+ FailureOr<Value> expandedValue =
+ options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
+ reassociationMap, options);
+ if (failed(expandedValue)) {
+ // Note: newPadOp (and the collapsed source) were already created above,
+ // so failing here leaves them behind as dead IR.
+ return rewriter.notifyMatchFailure(padOp, "Failed to expand result");
+ }
+ rewriter.replaceOp(padOp, expandedValue.value());
return success();
}
@@ -904,10 +914,12 @@ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
auto valType = cast<ShapedType>(val.getType());
SmallVector<int64_t> collapsedShape(valType.getShape());
collapsedShape.erase(collapsedShape.begin() + pos);
- return collapseValue(
+ ControlDropUnitDims control{}; // Default collapseFn fails on encodings/non-identity layouts.
+ FailureOr<Value> collapsed = control.collapseFn(
rewriter, val.getLoc(), val, collapsedShape,
- getReassociationForReshapeAtDim(valType.getRank(), pos),
- ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+ getReassociationForReshapeAtDim(valType.getRank(), pos), control);
+ assert(llvm::succeeded(collapsed) && "Collapsing the value failed");
+ return collapsed.value();
}
/// Base class for all rank reduction patterns for contraction ops
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 9005110205630..55b47bc2e9714 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -838,8 +838,9 @@ func.func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1
// -----
-// Test that nothing changes and no assertions are fired for memrefs with affine
-// maps while still changing the other operations.
+
+// Negative test: %arg0 is a memref with a non-identity layout, so the
+// transformation is aborted and the output should be identical to the input.
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
@@ -863,24 +864,21 @@ func.func @input_stays_same(%arg0 : memref<?x1x?xf32, strided<[?, 1, 1]>>, %arg1
return %shape : memref<?x1x?x1x?xf32>
}
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @input_stays_same(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x1x?xf32, strided<[?, 1, 1]>>,
-// CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
-// CHECK-SAME: -> memref<?x1x?x1x?xf32> {
-// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
-// CHECK-SAME: : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
-// CHECK: linalg.generic
-// CHECK-SAME: {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, strided<[?, 1, 1]>>, f32)
-// CHECK-SAME: outs(%[[OUT]] : memref<?x?x?xf32>) {
-// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32):
-// CHECK: linalg.yield %[[ARG]] : f32
-// CHECK: }
-// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
+// CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+// CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>
+// CHECK: func.func @input_stays_same(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x1x?xf32, strided<[?, 1, 1]>>
+// CHECK-SAME: %[[ARG1:.*]]: f32
+// CHECK-SAME: %[[ARG2:.*]]: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, strided<[?, 1, 1]>>, f32)
+// CHECK-SAME: outs(%[[ARG2]] : memref<?x1x?x1x?xf32>)
+// CHECK: ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
+// CHECK: linalg.yield %[[VAL_1]] : f32
+// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
// -----
diff --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
index 35eeffc1f9953..16ab3247a7d60 100644
--- a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
+++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s --check-prefixes=CHECK,NOENCODE
+// RUN: mlir-opt -test-linalg-drop-unit-dims=collapse-encoded --split-input-file %s | FileCheck %s --check-prefixes=CHECK,ENCODE
// Drop only the outermost unit dimension (controlled using a control function)
+// This test does not use an encoding, therefore behavior in both modes is identical.
func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> {
%0 = tensor.empty() : tensor<1x1x42xf32>
%1 = linalg.generic {
@@ -24,3 +26,98 @@ func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42x
// CHECK-SAME: outs(%[[OUTS_RESHAPE]] :
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
// CHECK: return %[[EXPAND_SHAPE]]
+
+// -----
+
+// Drop outermost unit dimension with an operand that has an encoding.
+// With the default behavior, the transformation is aborted and the linalg.generic remains unchanged (a dead collapse of %arg0 may be left behind).
+// With the custom behavior, the operand gets collapsed and the encoding is preserved in the collapse.
+
+#encoding = #test.tensor_encoding<"encoding">
+
+func.func @drop_unit_dims_encoded_operand(%arg0: tensor<1x1x42xf32>, %arg1: tensor<1x1x42xf32, #encoding>) -> tensor<1x1x42xf32> {
+ %0 = tensor.empty() : tensor<1x1x42xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<1x1x42xf32>, tensor<1x1x42xf32, #encoding>) outs(%0 : tensor<1x1x42xf32>) {
+ ^bb0(%in0: f32, %in1 : f32, %out : f32):
+ %2 = arith.addf %in0, %in1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1x1x42xf32>
+ return %1 : tensor<1x1x42xf32>
+}
+
+// NOENCODE-LABEL: @drop_unit_dims_encoded_operand(
+// NOENCODE-SAME: %[[ARG0:.*]]: tensor<1x1x42xf32>,
+// NOENCODE-SAME: %[[ARG1:.*]]: tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>) -> tensor<1x1x42xf32> {
+// NOENCODE: %[[EMPTY_0:.*]] = tensor.empty() : tensor<1x1x42xf32>
+// NOENCODE: %[[GENERIC_0:.*]] = linalg.generic
+// NOENCODE-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// NOENCODE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x1x42xf32>, tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>)
+// NOENCODE-SAME: outs(%[[EMPTY_0]] : tensor<1x1x42xf32>)
+// NOENCODE: return %[[GENERIC_0]] : tensor<1x1x42xf32>
+
+// ENCODE-LABEL: @drop_unit_dims_encoded_operand(
+// ENCODE-SAME: %[[ARG0:.*]]: tensor<1x1x42xf32>,
+// ENCODE-SAME: %[[ARG1:.*]]: tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>) -> tensor<1x1x42xf32> {
+// ENCODE: %[[EMPTY_0:.*]] = tensor.empty() : tensor<1x1x42xf32>
+// ENCODE: %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32> into tensor<1x42xf32>
+// ENCODE: %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x42xf32, #test.tensor_encoding<"encoding">>
+// ENCODE: %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[EMPTY_0]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32> into tensor<1x42xf32>
+// ENCODE: %[[GENERIC_0:.*]] = linalg.generic
+// ENCODE-SAME: iterator_types = ["parallel", "parallel"]
+// ENCODE-SAME: ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<1x42xf32>, tensor<1x42xf32, #test.tensor_encoding<"encoding">>)
+// ENCODE-SAME: outs(%[[COLLAPSE_SHAPE_2]] : tensor<1x42xf32>)
+// ENCODE: %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1], [2]] output_shape [1, 1, 42] : tensor<1x42xf32> into tensor<1x1x42xf32>
+// ENCODE: return %[[EXPAND_SHAPE_0]] : tensor<1x1x42xf32>
+
+// -----
+
+// Drop outermost unit dimension with a result that has an encoding.
+// With the default behavior, the transformation is aborted and the linalg.generic remains unchanged (dead collapses of the inputs may be left behind).
+// With the custom behavior, the result gets expanded and the encoding is preserved in the expansion.
+
+#encoding = #test.tensor_encoding<"encoding">
+
+func.func @drop_unit_dims_encoded_result(%arg0: tensor<1x1x42xf32>, %arg1: tensor<1x1x42xf32>) -> tensor<1x1x42xf32, #encoding> {
+ %0 = tensor.empty() : tensor<1x1x42xf32, #encoding>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<1x1x42xf32>, tensor<1x1x42xf32>) outs(%0 : tensor<1x1x42xf32, #encoding>) {
+ ^bb0(%in0: f32, %in1 : f32, %out : f32):
+ %2 = arith.addf %in0, %in1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1x1x42xf32, #encoding>
+ return %1 : tensor<1x1x42xf32, #encoding>
+}
+
+// NOENCODE-LABEL: @drop_unit_dims_encoded_result(
+// NOENCODE-SAME: %[[ARG0:.*]]: tensor<1x1x42xf32>,
+// NOENCODE-SAME: %[[ARG1:.*]]: tensor<1x1x42xf32>) -> tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+// NOENCODE: %[[EMPTY_0:.*]] = tensor.empty() : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+// NOENCODE: %[[GENERIC_0:.*]] = linalg.generic
+// NOENCODE-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// NOENCODE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x1x42xf32>, tensor<1x1x42xf32>)
+// NOENCODE-SAME: outs(%[[EMPTY_0]] : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>)
+// NOENCODE-NOT: tensor.expand_shape
+// NOENCODE: return %[[GENERIC_0]] : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+
+// ENCODE-LABEL: @drop_unit_dims_encoded_result(
+// ENCODE-SAME: %[[ARG0:.*]]: tensor<1x1x42xf32>,
+// ENCODE-SAME: %[[ARG1:.*]]: tensor<1x1x42xf32>) -> tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+// ENCODE: %[[EMPTY_0:.*]] = tensor.empty() : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+// ENCODE: %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32> into tensor<1x42xf32>
+// ENCODE: %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32> into tensor<1x42xf32>
+// ENCODE: %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[EMPTY_0]] {{\[\[}}0, 1], [2]] : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x42xf32, #test.tensor_encoding<"encoding">>
+// ENCODE: %[[GENERIC_0:.*]] = linalg.generic
+// ENCODE-SAME: iterator_types = ["parallel", "parallel"]
+// ENCODE-SAME: ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<1x42xf32>, tensor<1x42xf32>)
+// ENCODE-SAME: outs(%[[COLLAPSE_SHAPE_2]] : tensor<1x42xf32, #test.tensor_encoding<"encoding">>)
+// ENCODE: %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1], [2]] output_shape [1, 1, 42] : tensor<1x42xf32, #test.tensor_encoding<"encoding">> into tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
+// ENCODE: return %[[EXPAND_SHAPE_0]] : tensor<1x1x42xf32, #test.tensor_encoding<"encoding">>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 402ce154c0848..bf7c422aeb0f9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -34,13 +34,72 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
return success();
}
+LogicalResult dropOutermostUnitDimsWithEncoding(RewriterBase &rewriter,
+ linalg::GenericOp genericOp) {
+ linalg::ControlDropUnitDims options; // Default ReassociativeReshape strategy, asserted by the callbacks below.
+ options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
+ options.collapseFn =
+ [](RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const linalg::ControlDropUnitDims &control) -> FailureOr<Value> {
+ if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
+ if (tensorType.getEncoding()) {
+ assert(control.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::
+ ReassociativeReshape &&
+ "unexpected rank reduction strategy");
+ auto targetType = RankedTensorType::get(
+ targetShape, tensorType.getElementType(), tensorType.getEncoding());
+ return tensor::CollapseShapeOp::create(rewriter, loc, targetType,
+ operand, reassociation)
+ .getResult();
+ }
+ }
+ return linalg::collapseValue(rewriter, loc, operand, targetShape,
+ reassociation, control);
+ };
+ options.expandFn =
+ [](RewriterBase &rewriter, Location loc, Value result, Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const linalg::ControlDropUnitDims &control) -> FailureOr<Value> {
+ if (auto tensorType = dyn_cast<RankedTensorType>(origDest.getType())) {
+ if (tensorType.getEncoding()) {
+ assert(control.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::
+ ReassociativeReshape &&
+ "unexpected rank reduction strategy");
+ return tensor::ExpandShapeOp::create(rewriter, loc, tensorType, result,
+ reassociation)
+ .getResult();
+ }
+ }
+ return linalg::expandValue(rewriter, loc, result, origDest, reassociation,
+ control);
+ };
+ // Run the transformation; encoded tensors are reshaped by the callbacks above.
+ FailureOr<linalg::DropUnitDimsResult> result =
+ linalg::dropUnitDims(rewriter, genericOp, options);
+ if (failed(result)) {
+ return failure();
+ }
+ rewriter.replaceOp(genericOp, result->replacements);
+ return success();
+}
+
struct TestLinalgDropUnitDims
: public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
TestLinalgDropUnitDims() = default;
- TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
+ TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass)
+ : PassWrapper(pass) {}
+
+ Option<bool> collapseEncoded{
+ *this, "collapse-encoded",
+ llvm::cl::desc("Collapse and expand tensors with encodings"),
+ llvm::cl::init(false)};
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect>();
@@ -63,6 +122,10 @@ struct TestLinalgDropUnitDims
for (auto genericOp : genericOps) {
rewriter.setInsertionPoint(genericOp);
+ if (collapseEncoded) {
+ (void)dropOutermostUnitDimsWithEncoding(rewriter, genericOp);
+ continue;
+ }
(void)dropOutermostUnitDims(rewriter, genericOp);
}
}
>From d74225370980133399c2087796373df9671b7799 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Wed, 10 Dec 2025 17:24:10 +0000
Subject: [PATCH 2/2] Factor out canonicalization patterns
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 8 ++-
.../Linalg/Transforms/DropUnitDims.cpp | 69 +++++++++----------
2 files changed, 39 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e47ed3f0873ad..c78824e75decc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -2089,10 +2089,16 @@ void populateFuseTensorPadWithProducerLinalgOpPatterns(
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors via reassociative reshape ops.
+/// tensors and memrefs.
+/// Note that these patterns should not be used with a greedy driver; apply
+/// them with a walk-based driver such as `walkAndApplyPatterns` instead.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
ControlDropUnitDims &options);
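+/// Populates canonicalization patterns that simplify IR after folding
+/// unit-extent dimensions. The rank-reduced slice patterns and the
+/// `tensor.collapse_shape` / `tensor.expand_shape` canonicalizations are only
+/// added when `options.rankReductionStrategy` is `ReassociativeReshape`.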
+void populateFoldUnitExtentDimsCanonicalizationPatterns(
+ RewritePatternSet &patterns, ControlDropUnitDims &options);
+
/// A pattern that converts init operands to input operands.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 0fb7ca08b0ace..b7d278a4c4d2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
namespace mlir {
@@ -809,33 +810,27 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-static void
-populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
- ControlDropUnitDims &options) {
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+ RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
+ // Only the unit-dim folding rewrites are registered here; the follow-up
+ // cleanup patterns are provided separately by
+ // populateFoldUnitExtentDimsCanonicalizationPatterns. Per the declaration's
+ // documentation, these should not be run with a greedy driver.
auto *context = patterns.getContext();
patterns.add<DropUnitDims>(context, options);
patterns.add<DropPadUnitDims>(context, options);
- // TODO: Patterns unrelated to unit dim folding should be factored out.
- patterns.add<RankReducedExtractSliceOp,
- RankReducedInsertSliceOp<tensor::InsertSliceOp>,
- RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
- context);
- linalg::FillOp::getCanonicalizationPatterns(patterns, context);
- tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
- tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
- tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
- tensor::populateFoldTensorEmptyPatterns(patterns);
- memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
-static void
-populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
- ControlDropUnitDims &options) {
+void mlir::linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(
+ RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
auto *context = patterns.getContext();
- patterns.add<DropUnitDims>(context, options);
- patterns.add<DropPadUnitDims>(context, options);
- // TODO: Patterns unrelated to unit dim folding should be factored out.
+ bool reassociativeReshape =
+ options.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape;
+ if (reassociativeReshape) {
+ patterns.add<RankReducedExtractSliceOp,
+ RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+ RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
+ context);
+ tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+ tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+ }
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::populateFoldTensorEmptyPatterns(patterns);
@@ -843,18 +838,6 @@ populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
-void mlir::linalg::populateFoldUnitExtentDimsPatterns(
- RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
- if (options.rankReductionStrategy ==
- linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
- populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
- } else if (options.rankReductionStrategy ==
- linalg::ControlDropUnitDims::RankReductionStrategy::
- ReassociativeReshape) {
- populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
- }
-}
-
void mlir::linalg::populateMoveInitOperandsToInputPattern(
RewritePatternSet &patterns) {
patterns.add<MoveInitOperandsToInput>(patterns.getContext());
@@ -870,15 +853,27 @@ struct LinalgFoldUnitExtentDimsPass
+ /// Folds unit-extent dims in two phases: the folding patterns run under the
+ /// walk-based driver, then the init-operand-to-input pattern and the
+ /// follow-up canonicalization patterns run under the greedy driver.
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
- RewritePatternSet patterns(context);
ControlDropUnitDims options;
if (useRankReducingSlices) {
options.rankReductionStrategy = linalg::ControlDropUnitDims::
RankReductionStrategy::ExtractInsertSlice;
}
- linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
- populateMoveInitOperandsToInputPattern(patterns);
- (void)applyPatternsGreedily(op, std::move(patterns));
+
+ // Phase 1: apply the unit-extent-dim folding patterns with the walk-based
+ // driver, since they are documented as not suitable for a greedy driver.
+ // Each phase builds its own RewritePatternSet because the set is moved
+ // into the driver.
+ {
+ RewritePatternSet patterns(context);
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+ walkAndApplyPatterns(op, std::move(patterns));
+ }
+
+ // Phase 2: clean up with the greedy driver, using the canonicalization
+ // patterns selected for the same rank-reduction strategy plus the
+ // init-operand-to-input pattern.
+ {
+ RewritePatternSet patterns(context);
+ populateMoveInitOperandsToInputPattern(patterns);
+ linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(patterns,
+ options);
+ (void)applyPatternsGreedily(op, std::move(patterns));
+ }
}
};
More information about the Mlir-commits
mailing list