[Mlir-commits] [mlir] 3ebe5d6 - [mlir][linalg] Drop unit dims on IndexingMapOpInterface (#150280)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 24 08:07:53 PDT 2025
Author: Ian Wood
Date: 2025-07-24T16:07:51+01:00
New Revision: 3ebe5d661f7829b2ffe1b422ec7d00d3213c9730
URL: https://github.com/llvm/llvm-project/commit/3ebe5d661f7829b2ffe1b422ec7d00d3213c9730
DIFF: https://github.com/llvm/llvm-project/commit/3ebe5d661f7829b2ffe1b422ec7d00d3213c9730.diff
LOG: [mlir][linalg] Drop unit dims on IndexingMapOpInterface (#150280)
Generalizes `dropUnitDims` to operate on any op implementing
`IndexingMapOpInterface` (the op must also implement
`DestinationStyleOpInterface`). Operation-specific creation is handled by
passing a builder callback that constructs the new operation from the new
operands, the new indexing maps, and the set of dropped dimensions.
---------
Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
Co-authored-by: Kunwar Grover <groverkss at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 38e53648e7c34..e625eefac5f78 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -537,10 +537,20 @@ struct ControlDropUnitDims {
return SmallVector<unsigned>{};
};
};
+
+/// Result of `dropUnitDims`: the newly created operation and the values to use
+/// in place of the original operation's results.
struct DropUnitDimsResult {
- linalg::GenericOp resultOp;
+ IndexingMapOpInterface resultOp;
SmallVector<Value> replacements;
};
+/// Callback that creates the replacement operation once unit dimensions have
+/// been dropped. It receives the location and builder to use, the original
+/// operation, the (possibly collapsed) operands, the new indexing maps, and
+/// the set of dropped loop dimensions, and returns the newly created op.
+using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
+ Location loc, OpBuilder &, IndexingMapOpInterface,
+ ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+ const llvm::SmallDenseSet<unsigned> &droppedDims)>;
+
+/// Drops the unit-extent loop dimensions of `op` that `options` allows,
+/// collapsing operands as needed, and invokes `droppedUnitDimsBuilder` to
+/// create the replacement operation. `op` must also implement
+/// `DestinationStyleOpInterface`; otherwise this returns failure.
+FailureOr<DropUnitDimsResult>
+dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
+ const ControlDropUnitDims &options);
+/// Overload for `linalg.generic` that supplies a builder which recreates the
+/// `linalg.generic` with the reduced iteration space.
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..6c59cd65c1b99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo {
SmallVector<int64_t> targetShape;
};
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
- MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+ MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
ArrayRef<AffineExpr> dimReplacements) {
UnitExtentReplacementInfo info;
ReassociationIndices reassociationGroup;
SmallVector<AffineExpr> newIndexExprs;
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
+ SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
}
FailureOr<DropUnitDimsResult>
-linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options) {
- SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!dpsOp) {
+ return rewriter.notifyMatchFailure(
+ op, "op should implement DestinationStyleOpInterface");
+ }
+
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
@@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
AffineMap invertedMap =
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
if (!invertedMap) {
- return rewriter.notifyMatchFailure(genericOp,
+ return rewriter.notifyMatchFailure(op,
"invalid indexing maps for operation");
}
SmallVector<int64_t> allShapesSizes;
- for (OpOperand &opOperand : genericOp->getOpOperands())
- llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
+ for (OpOperand &opOperand : op->getOpOperands())
+ llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
// 1a. Get the allowed list of dimensions to drop from the `options`.
- SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
if (allowedUnitDims.empty()) {
return rewriter.notifyMatchFailure(
- genericOp, "control function returns no allowed unit dims to prune");
+ op, "control function returns no allowed unit dims to prune");
}
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
allowedUnitDims.end());
@@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
}
- // 2. Compute the iterator types of the modified op by dropping the one-trip
+ // 2. Compute the new loops of the modified op by dropping the one-trip
// count loops.
- SmallVector<utils::IteratorType> newIteratorTypes;
llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
SmallVector<AffineExpr> dimReplacements;
unsigned newDims = 0;
- for (auto [index, attr] :
- llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
if (unitDims.count(index)) {
dimReplacements.push_back(
getAffineConstantExpr(0, rewriter.getContext()));
} else {
- newIteratorTypes.push_back(attr);
oldDimToNewDimMap[index] = newDims;
dimReplacements.push_back(
getAffineDimExpr(newDims, rewriter.getContext()));
@@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
return false;
};
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
- ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto indexingMap = op.getMatchingIndexingMap(&opOperand);
+ SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
if (!hasCollapsibleType(opOperand)) {
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
@@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
reassociations.push_back({});
continue;
}
- auto replacementInfo = dropUnitExtentFromOperandMetadata(
- rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
- dimReplacements);
+ auto replacementInfo =
+ dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
+ oldDimToNewDimMap, dimReplacements);
reassociations.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
targetShapes.push_back(replacementInfo.targetShape);
@@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
return failure();
- Location loc = genericOp.getLoc();
+ Location loc = op.getLoc();
// 4. For each of the operands, collapse the operand to convert
// from original shape to shape in the modified operation if needed,
// either through use of reshapes or rank-reducing slices as
// specified in `options`.
SmallVector<Value> newOperands;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ for (OpOperand &opOperand : op->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
if (!collapsed[idx]) {
newOperands.push_back(opOperand.get());
@@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
options.rankReductionStrategy));
}
- // 5. Create the `linalg.generic` operation with the new operands,
- // indexing maps, iterator types and result types.
- ArrayRef<Value> newInputs =
- ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
- ArrayRef<Value> newOutputs =
- ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
- SmallVector<Type> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newOutputs[i].getType());
- GenericOp replacementOp =
- rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
- newIndexingMaps, newIteratorTypes);
- rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
- replacementOp.getRegion().begin());
- // 5a. Replace `linalg.index` operations that refer to the dropped unit
- // dimensions.
- replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+ IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
+ loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
// 6. If any result type changes, insert a reshape/slice to convert from the
// original type to the new type.
SmallVector<Value> resultReplacements;
- for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
- unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
- Value origDest = genericOp.getDpsInitOperand(index)->get();
+ for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
+ unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
+ Value origDest = dpsOp.getDpsInitOperand(index)->get();
if (!collapsed[opOperandIndex]) {
resultReplacements.push_back(result);
continue;
@@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
return DropUnitDimsResult{replacementOp, resultReplacements};
}
+/// Convenience overload of `dropUnitDims` for `linalg.generic`.
+///
+/// Delegates the unit-dimension analysis and operand/result reshaping to the
+/// `IndexingMapOpInterface` overload, and supplies a builder that recreates
+/// the `linalg.generic` with the one-trip loops removed.
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+                     const ControlDropUnitDims &options) {
+  // Distinct local names are used inside the lambda so they do not shadow
+  // the enclosing `rewriter` / `genericOp` parameters.
+  DroppedUnitDimsBuilder build =
+      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
+         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+         const llvm::SmallDenseSet<unsigned> &droppedDims)
+      -> IndexingMapOpInterface {
+    // This builder is only handed to the interface overload together with a
+    // GenericOp (see the call below), so the cast is expected to succeed.
+    auto origGenericOp = cast<GenericOp>(op);
+
+    // Keep the iterator types of the surviving loops; the one-trip loops
+    // listed in `droppedDims` are removed from the iteration space.
+    SmallVector<utils::IteratorType> newIteratorTypes;
+    for (auto [index, attr] :
+         llvm::enumerate(origGenericOp.getIteratorTypesArray())) {
+      if (!droppedDims.count(index))
+        newIteratorTypes.push_back(attr);
+    }
+
+    // Split the rewritten operands back into DPS inputs and inits; each
+    // result takes the type of its (possibly collapsed) init operand.
+    ArrayRef<Value> newInputs =
+        newOperands.take_front(origGenericOp.getNumDpsInputs());
+    ArrayRef<Value> newOutputs =
+        newOperands.take_back(origGenericOp.getNumDpsInits());
+    SmallVector<Type> resultTypes;
+    resultTypes.reserve(origGenericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, origGenericOp.getNumResults()))
+      resultTypes.push_back(newOutputs[i].getType());
+
+    GenericOp replacementOp =
+        b.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+                            newIndexingMaps, newIteratorTypes);
+    // Only an OpBuilder is available here, so the body is cloned rather than
+    // moved; the original op is not modified by this builder.
+    b.cloneRegionBefore(origGenericOp.getRegion(), replacementOp.getRegion(),
+                        replacementOp.getRegion().begin());
+
+    // Replace `linalg.index` operations that refer to the dropped unit
+    // dimensions.
+    IRRewriter indexRewriter(b);
+    replaceUnitDimIndexOps(replacementOp, droppedDims, indexRewriter);
+
+    return replacementOp;
+  };
+
+  return dropUnitDims(rewriter, genericOp, build, options);
+}
+
namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
More information about the Mlir-commits mailing list