[Mlir-commits] [mlir] [mlir][Linalg] Allow more control in drop unit dims (PR #171796)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 11 02:31:44 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Lukas Sommer (sommerlukas)
<details>
<summary>Changes</summary>
Extend the ControlDropUnitDims struct to allow users of the `linalg::dropUnitDims` function more control over the behavior of the function.
The extended struct allows users to specify functions to control how the operands are collapsed and how the result is expanded to the original shape.
One example (and the motivation for this change) where this additional control is useful is to allow collapsing of tensors with an encoding, as demonstrated by the new test.
This is a breaking change. The new default behavior is to abort the transformation if one of the operands cannot be collapsed or if the result cannot be expanded. This is the case for `memref`s with a non-identity layout and `tensor`s with an encoding.
---
Patch is 39.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171796.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+104-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+105-98)
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+18-20)
- (modified) mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir (+98-1)
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp (+64-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..c78824e75decc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -516,6 +516,28 @@ LogicalResult vectorizeOpPrecondition(Operation *op,
using LinalgLoops = SmallVector<Operation *, 4>;
+// Forward declaration
+struct ControlDropUnitDims;
+
+/// Collapse the given \p operand to \p targetShape. The \p reassociation is
+/// used when `rankReductionStrategy` of \p control is set to
+/// `RankReductionStrategy::ReassociativeReshape`. Will return failure if the
+/// operand has memref type with a non-identity layout or tensor type with an
+/// encoding.
+FailureOr<Value> collapseValue(RewriterBase &rewriter, Location loc,
+ Value operand, ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control);
+
+/// Expand the given \p result so that the type matches the type of \p origDest.
+/// The \p reassociation is used when `rankReductionStrategy` of \p control is
+/// set to `RankReductionStrategy::ReassociativeReshape`. Will return failure if
+/// the original destination has tensor type with an encoding.
+FailureOr<Value> expandValue(RewriterBase &rewriter, Location loc, Value result,
+ Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control);
+
/// Transformation to drop unit-extent dimensions from `linalg.generic`
/// operations.
struct ControlDropUnitDims {
@@ -524,7 +546,19 @@ struct ControlDropUnitDims {
RankReductionStrategy rankReductionStrategy =
RankReductionStrategy::ReassociativeReshape;
+ /// Instances of this type are used to control which dimensions of an operand
+ /// are considered for dropping unit extent dimensions. The parameter to the
+ /// function is the operation itself, the expected return is a list of
+ /// dimensions to consider for dropping unit extent dimensions. If the
+ /// operation should not have any dimensions dropped, implementations
+ /// should return an empty list.
using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
+
+ /// Function to control which dimensions, if any, are to be considered for
+ /// dropping unit extent dimensions. The default behavior is to consider all
+ /// dimensions of a \c linalg.generic or \c tensor.pad operation for dropping.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
ControlFnTy controlFn = [](Operation *op) {
if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
@@ -535,6 +569,58 @@ struct ControlDropUnitDims {
}
return SmallVector<unsigned>{};
};
+
+ /// Instances of this type are used to control how operand values are
+ /// collapsed after dropping unit extent dimensions. Next to the control
+ /// struct, rewriter and location, the function receives the operand value to
+ /// collapse, the new target shape and how old dimensions should be grouped.
+ /// The function needs to insert the necessary operations to collapse the
+ /// operand to the target shape and returns the new operand value.
+ /// If the operand should not be collapsed, the function should return
+ /// failure, leading to the transformation to be aborted.
+ using CollapseFnTy = std::function<FailureOr<Value>(
+ RewriterBase &, Location, Value, ArrayRef<int64_t>,
+ ArrayRef<ReassociationIndices>, const ControlDropUnitDims &)>;
+
+ /// Function to control how operands are collapsed into their new target shape
+ /// after dropping unit extent dimensions. For the default behavior
+ /// \see linalg::collapseValue.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
+ CollapseFnTy collapseFn =
+ [](RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) -> FailureOr<Value> {
+ return linalg::collapseValue(rewriter, loc, operand, targetShape,
+ reassociation, control);
+ };
+
+ /// Instances of this type are used to control how result values are expanded
+ /// into their original shape after dropping unit extent dimensions. Next to
+ /// the control construct, rewriter and location, the function receives the
+ /// result value, the original value to replace and information on how the
+ /// new dimensions were grouped.
+ /// The function needs to insert the necessary operations to expand the
+ /// result to the original shape and returns the new result value.
+ /// If the result should not be expanded, the function should return
+ /// failure, leading to the transformation to be aborted.
+ using ExpandFnTy = std::function<FailureOr<Value>(
+ RewriterBase &, Location, Value, Value, ArrayRef<ReassociationIndices>,
+ const ControlDropUnitDims &)>;
+
+ /// Function to control how results are expanded into their original shape
+ /// after dropping unit extent dimensions. For the default behavior
+ /// \see linalg::expandValue.
+ /// Users of the \ref dropUnitDims interface can override the default behavior
+ /// by setting this member to their own implementation.
+ ExpandFnTy expandFn =
+ [](RewriterBase &rewriter, Location loc, Value result, Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) -> FailureOr<Value> {
+ return linalg::expandValue(rewriter, loc, result, origDest, reassociation,
+ control);
+ };
};
struct DropUnitDimsResult {
@@ -546,10 +632,21 @@ using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
const llvm::SmallDenseSet<unsigned> &droppedDims)>;
+/// Drop unit extent dimensions from the \p op and its operands.
+/// The transformation is aborted if unit dimensions cannot be dropped from any
+/// of the operands. Note that this function may insert trivially dead
+/// operations if the transformation is aborted and should therefore not be
+/// called from greedy drivers.
FailureOr<DropUnitDimsResult>
dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options);
+
+/// Drop unit extent dimensions from the \p genericOp and its operands.
+/// The transformation is aborted if unit dimensions cannot be dropped from any
+/// of the operands. Note that this function may insert trivially dead
+/// operations if the transformation is aborted and should therefore not be
+/// called from greedy drivers.
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);
@@ -1992,10 +2089,16 @@ void populateFuseTensorPadWithProducerLinalgOpPatterns(
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors via reassociative reshape ops.
+/// tensors and memrefs.
+/// Note that these patterns should not be used with a greedy driver.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
ControlDropUnitDims &options);
+/// Populates canonicalization patterns that simplify IR after folding
+/// unit-extent dimensions.
+void populateFoldUnitExtentDimsCanonicalizationPatterns(
+ RewritePatternSet &patterns, ControlDropUnitDims &options);
+
/// A pattern that converts init operands to input operands.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9e6c1e6036cba..b7d278a4c4d2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
namespace mlir {
@@ -244,16 +245,19 @@ replaceUnitDimIndexOps(GenericOp genericOp,
}
}
-/// Expand the given `value` so that the type matches the type of `origDest`.
-/// The `reassociation` is used when `rankReductionStrategy` is set to
-/// `RankReductionStrategy::ReassociativeReshape`.
-static Value
-expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
- ArrayRef<ReassociationIndices> reassociation,
- ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+FailureOr<Value>
+linalg::expandValue(RewriterBase &rewriter, Location loc, Value result,
+ Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) {
// There are no results for memref outputs.
auto origResultType = cast<RankedTensorType>(origDest.getType());
- if (rankReductionStrategy ==
+ origResultType.dump();
+ if (origResultType.getEncoding() != nullptr) {
+ // Do not expand tensors with encoding.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
unsigned rank = origResultType.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
@@ -264,7 +268,7 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
loc, result, origDest, offsets, sizes, strides);
}
- assert(rankReductionStrategy ==
+ assert(control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
@@ -272,15 +276,17 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
.getResult();
}
-/// Collapse the given `value` so that the type matches the type of
-/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
-/// set to `RankReductionStrategy::ReassociativeReshape`.
-static Value collapseValue(
- RewriterBase &rewriter, Location loc, Value operand,
- ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
- ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+FailureOr<Value>
+linalg::collapseValue(RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ const ControlDropUnitDims &control) {
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- if (rankReductionStrategy ==
+ if (!memrefType.getLayout().isIdentity()) {
+ // Do not collapse memrefs with a non-identity layout.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -290,17 +296,22 @@ static Value collapseValue(
}
assert(
- rankReductionStrategy ==
+ control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
MemRefLayoutAttrInterface layout;
auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
layout, memrefType.getMemorySpace());
return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
- reassociation);
+ reassociation)
+ .getResult();
}
if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
- if (rankReductionStrategy ==
+ if (tensorType.getEncoding() != nullptr) {
+ // Do not collapse tensors with an encoding.
+ return failure();
+ }
+ if (control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -310,13 +321,14 @@ static Value collapseValue(
}
assert(
- rankReductionStrategy ==
+ control.rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
auto targetType =
RankedTensorType::get(targetShape, tensorType.getElementType());
return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
- reassociation);
+ reassociation)
+ .getResult();
}
llvm_unreachable("unsupported operand type");
}
@@ -457,28 +469,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
SmallVector<SmallVector<ReassociationIndices>> reassociations;
SmallVector<SmallVector<int64_t>> targetShapes;
SmallVector<bool> collapsed;
- auto hasCollapsibleType = [](OpOperand &operand) {
- Type operandType = operand.get().getType();
- if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
- return memrefOperandType.getLayout().isIdentity();
- }
- if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
- return tensorOperandType.getEncoding() == nullptr;
- }
- return false;
- };
for (OpOperand &opOperand : op->getOpOperands()) {
auto indexingMap = op.getMatchingIndexingMap(&opOperand);
- SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
- if (!hasCollapsibleType(opOperand)) {
- AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
- dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
- newIndexingMaps.push_back(newIndexingMap);
- targetShapes.push_back(llvm::to_vector(shape));
- collapsed.push_back(false);
- reassociations.push_back({});
- continue;
- }
auto replacementInfo =
dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
oldDimToNewDimMap, dimReplacements);
@@ -501,6 +493,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
// from original shape to shape in the modified operation if needed,
// either through use of reshapes or rank-reducing slices as
// specified in `options`.
+ // Abort if one of the operands cannot be collapsed.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
@@ -508,9 +501,14 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
newOperands.push_back(opOperand.get());
continue;
}
- newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
- targetShapes[idx], reassociations[idx],
- options.rankReductionStrategy));
+ FailureOr<Value> collapsed =
+ options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
+ reassociations[idx], options);
+ if (failed(collapsed)) {
+ // Abort if the operand could not be collapsed.
+ return failure();
+ }
+ newOperands.push_back(collapsed.value());
}
IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
@@ -518,6 +516,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
// 6. If any result type changes, insert a reshape/slice to convert from the
// original type to the new type.
+ // Abort the transformation if the result cannot be expanded back to its
+ // original shape.
SmallVector<Value> resultReplacements;
for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
@@ -526,10 +526,14 @@ linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
resultReplacements.push_back(result);
continue;
}
- Value expandedValue = expandValue(rewriter, loc, result, origDest,
- reassociations[opOperandIndex],
- options.rankReductionStrategy);
- resultReplacements.push_back(expandedValue);
+ FailureOr<Value> expanded =
+ options.expandFn(rewriter, loc, result, origDest,
+ reassociations[opOperandIndex], options);
+ if (failed(expanded)) {
+ // Abort if expansion is not successful.
+ return failure();
+ }
+ resultReplacements.push_back(expanded.value());
}
return DropUnitDimsResult{replacementOp, resultReplacements};
@@ -685,15 +689,19 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
reassociationGroup.clear();
}
- Value collapsedSource =
- collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
- reassociationMap, options.rankReductionStrategy);
+ FailureOr<Value> collapsedSource =
+ options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
+ newShape, reassociationMap, options);
+ if (failed(collapsedSource)) {
+ return rewriter.notifyMatchFailure(padOp, "Failed to collapse source");
+ }
auto newResultType = RankedTensorType::get(
newResultShape, padOp.getResultType().getElementType());
auto newPadOp = tensor::PadOp::create(
- rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
- newLowPad, newHighPad, paddingVal, padOp.getNofold());
+ rewriter, padOp.getLoc(), /*result=*/newResultType,
+ collapsedSource.value(), newLowPad, newHighPad, paddingVal,
+ padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -713,10 +721,13 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
padOp.getResultType().getElementType());
}
- Value expandedValue =
- expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
- reassociationMap, options.rankReductionStrategy);
- rewriter.replaceOp(padOp, expandedValue);
+ FailureOr<Value> expandedValue =
+ options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
+ reassociationMap, options);
+ if (failed(expandedValue)) {
+ return rewriter.notifyMatchFailure(padOp, "Failed to expand result");
+ }
+ rewriter.replaceOp(padOp, expandedValue.value());
return success();
}
@@ -799,33 +810,27 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-static void
-populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
- ControlDropUnitDims &options) {
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+ RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
auto *context = patterns.getContext();
patterns.add<DropUnitDims>(context, options);
patterns.add<DropPadUnitDims>(context, options);
- // TODO: Patterns unrelated to unit dim folding should be factored out.
- patterns.add<RankReducedExtractSliceOp,
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/171796
More information about the Mlir-commits
mailing list