[Mlir-commits] [mlir] [mlir][linalg] Drop unit dims on IndexingMapOpInterface (PR #150280)
Ian Wood
llvmlistbot at llvm.org
Wed Jul 23 11:41:32 PDT 2025
https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/150280
>From cc8510076d54d4f520f75bab4d87dda7910509cc Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 21 Jul 2025 17:41:50 +0100
Subject: [PATCH 1/2] [mlir][linalg] Move dropUnitDims to work on
IndexingMapOpInterface
---
.../Dialect/Linalg/Transforms/Transforms.h | 12 +-
.../Linalg/Transforms/DropUnitDims.cpp | 121 +++++++++++-------
2 files changed, 88 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 38e53648e7c34..8d4abb0d5810c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -537,10 +537,20 @@ struct ControlDropUnitDims {
return SmallVector<unsigned>{};
};
};
+
struct DropUnitDimsResult {
- linalg::GenericOp resultOp;
+ IndexingMapOpInterface resultOp;
SmallVector<Value> replacements;
};
+using DroppedUnitDimsBuilder = llvm::function_ref<IndexingMapOpInterface(
+ Location loc, OpBuilder &, IndexingMapOpInterface,
+ ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+ const llvm::SmallDenseSet<unsigned> &droppedDims)>;
+
+FailureOr<DropUnitDimsResult>
+dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+ DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+ const ControlDropUnitDims &options);
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..1312add2f9298 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo {
SmallVector<int64_t> targetShape;
};
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
- MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+ MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
ArrayRef<AffineExpr> dimReplacements) {
UnitExtentReplacementInfo info;
ReassociationIndices reassociationGroup;
SmallVector<AffineExpr> newIndexExprs;
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
+ SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
}
FailureOr<DropUnitDimsResult>
-linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+ DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options) {
- SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!dpsOp) {
+ return rewriter.notifyMatchFailure(
+ op, "op should implement DestinationStyleOpInterface");
+ }
+
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
@@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
AffineMap invertedMap =
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
if (!invertedMap) {
- return rewriter.notifyMatchFailure(genericOp,
+ return rewriter.notifyMatchFailure(op,
"invalid indexing maps for operation");
}
SmallVector<int64_t> allShapesSizes;
- for (OpOperand &opOperand : genericOp->getOpOperands())
- llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
+ for (OpOperand &opOperand : op->getOpOperands())
+ llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
// 1a. Get the allowed list of dimensions to drop from the `options`.
- SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
if (allowedUnitDims.empty()) {
return rewriter.notifyMatchFailure(
- genericOp, "control function returns no allowed unit dims to prune");
+ op, "control function returns no allowed unit dims to prune");
}
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
allowedUnitDims.end());
@@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
}
- // 2. Compute the iterator types of the modified op by dropping the one-trip
+ // 2. Compute the new loops of the modified op by dropping the one-trip
// count loops.
- SmallVector<utils::IteratorType> newIteratorTypes;
llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
SmallVector<AffineExpr> dimReplacements;
unsigned newDims = 0;
- for (auto [index, attr] :
- llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
if (unitDims.count(index)) {
dimReplacements.push_back(
getAffineConstantExpr(0, rewriter.getContext()));
} else {
- newIteratorTypes.push_back(attr);
oldDimToNewDimMap[index] = newDims;
dimReplacements.push_back(
getAffineDimExpr(newDims, rewriter.getContext()));
@@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
return false;
};
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
- ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto indexingMap = op.getMatchingIndexingMap(&opOperand);
+ SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
if (!hasCollapsibleType(opOperand)) {
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
@@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
reassociations.push_back({});
continue;
}
- auto replacementInfo = dropUnitExtentFromOperandMetadata(
- rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
- dimReplacements);
+ auto replacementInfo =
+ dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
+ oldDimToNewDimMap, dimReplacements);
reassociations.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
targetShapes.push_back(replacementInfo.targetShape);
@@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
return failure();
- Location loc = genericOp.getLoc();
+ Location loc = op.getLoc();
// 4. For each of the operands, collapse the operand to convert
// from original shape to shape in the modified operation if needed,
// either through use of reshapes or rank-reducing slices as
// specified in `options`.
SmallVector<Value> newOperands;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ for (OpOperand &opOperand : op->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
if (!collapsed[idx]) {
newOperands.push_back(opOperand.get());
@@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
options.rankReductionStrategy));
}
- // 5. Create the `linalg.generic` operation with the new operands,
- // indexing maps, iterator types and result types.
- ArrayRef<Value> newInputs =
- ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
- ArrayRef<Value> newOutputs =
- ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
- SmallVector<Type> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newOutputs[i].getType());
- GenericOp replacementOp =
- rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
- newIndexingMaps, newIteratorTypes);
- rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
- replacementOp.getRegion().begin());
- // 5a. Replace `linalg.index` operations that refer to the dropped unit
- // dimensions.
- replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+ IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
+ loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
// 6. If any result type changes, insert a reshape/slice to convert from the
// original type to the new type.
SmallVector<Value> resultReplacements;
- for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
- unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
- Value origDest = genericOp.getDpsInitOperand(index)->get();
+ for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
+ unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
+ Value origDest = dpsOp.getDpsInitOperand(index)->get();
if (!collapsed[opOperandIndex]) {
resultReplacements.push_back(result);
continue;
@@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
return DropUnitDimsResult{replacementOp, resultReplacements};
}
+/// Drops unit-extent loops from `genericOp` by delegating to the
+/// IndexingMapOpInterface overload, supplying a builder that recreates a
+/// `linalg.generic` from the collapsed operands and new indexing maps.
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                     const ControlDropUnitDims &options) {
+  // `DroppedUnitDimsBuilder` may be a non-owning `llvm::function_ref`;
+  // initializing it directly from a lambda temporary would dangle once that
+  // declaration ends. Keep the lambda in a named local that outlives the call.
+  auto buildReplacementOp =
+      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
+         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+         const llvm::SmallDenseSet<unsigned> &droppedDims)
+      -> IndexingMapOpInterface {
+    auto origGenericOp = cast<GenericOp>(op);
+    // Keep only the iterator types of loops that survive; dropped unit dims
+    // have no loop in the new op.
+    SmallVector<utils::IteratorType> newIteratorTypes;
+    for (auto [index, attr] :
+         llvm::enumerate(origGenericOp.getIteratorTypesArray())) {
+      if (!droppedDims.count(index))
+        newIteratorTypes.push_back(attr);
+    }
+
+    // Create the `linalg.generic` operation with the new operands, indexing
+    // maps and iterator types. Result types are taken from the (possibly
+    // collapsed) init operands.
+    ArrayRef<Value> newInputs = ArrayRef<Value>(newOperands).take_front(
+        origGenericOp.getNumDpsInputs());
+    ArrayRef<Value> newOutputs = ArrayRef<Value>(newOperands).take_back(
+        origGenericOp.getNumDpsInits());
+    SmallVector<Type> resultTypes;
+    resultTypes.reserve(origGenericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, origGenericOp.getNumResults()))
+      resultTypes.push_back(newOutputs[i].getType());
+    GenericOp replacementOp =
+        b.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+                            newIndexingMaps, newIteratorTypes);
+    // Only an OpBuilder is available here, so the body is cloned rather than
+    // moved; the original op's region is not modified by this callback.
+    b.cloneRegionBefore(origGenericOp.getRegion(), replacementOp.getRegion(),
+                        replacementOp.getRegion().begin());
+    // Replace `linalg.index` operations that refer to the dropped unit
+    // dimensions.
+    IRRewriter indexRewriter(b);
+    replaceUnitDimIndexOps(replacementOp, droppedDims, indexRewriter);
+
+    return replacementOp;
+  };
+  DroppedUnitDimsBuilder build = buildReplacementOp;
+  return dropUnitDims(rewriter, genericOp, build, options);
+}
+
namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
>From 425861e2769844e975e45b23fb54f8139ed5724b Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Wed, 23 Jul 2025 10:38:33 -0700
Subject: [PATCH 2/2] Fix -Wdangling warning
Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
---
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 4 ++--
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d4abb0d5810c..e625eefac5f78 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -542,14 +542,14 @@ struct DropUnitDimsResult {
IndexingMapOpInterface resultOp;
SmallVector<Value> replacements;
};
-using DroppedUnitDimsBuilder = llvm::function_ref<IndexingMapOpInterface(
+using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
Location loc, OpBuilder &, IndexingMapOpInterface,
ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
const llvm::SmallDenseSet<unsigned> &droppedDims)>;
FailureOr<DropUnitDimsResult>
dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
- DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options);
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 1312add2f9298..6c59cd65c1b99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -381,7 +381,7 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
- DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options) {
auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
if (!dpsOp) {
More information about the Mlir-commits
mailing list