[Mlir-commits] [mlir] [mlir][Linalg] NFC: Expose a method to deduplicate operands/remove dead results of `linalg.generic` op. (PR #125141)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 31 17:10:56 PST 2025
================
@@ -52,255 +52,266 @@ static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
return true;
}
-namespace {
-
-struct DeduplicateAndRemoveDeadOperandsAndResults
- : public OpRewritePattern<GenericOp> {
- DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
- bool removeOutputs)
- : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Create a map from argument position in the original op to the argument
- // position in the new op. If the argument is dropped it wont have an entry.
- SmallVector<OpOperand *> droppedOpOperands;
-
- // Information needed to build the new op.
- SmallVector<Value> newInputOperands, newOutputOperands;
- SmallVector<AffineMap> newIndexingMaps;
-
- // Gather information about duplicate input operands.
- llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
- deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
- newIndexingMaps);
-
- // Gather information about the dropped outputs.
- llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
- deduplicateOutputOperands(genericOp, droppedOpOperands,
- newOutputOperands, newIndexingMaps);
-
- // Check if there is any change to operands.
- if (newInputOperands.size() + newOutputOperands.size() ==
- genericOp->getNumOperands())
- return failure();
-
- // Create the new op with the body being empty.
- Location loc = genericOp.getLoc();
- SmallVector<Type> newResultTypes;
- for (Value v : newOutputOperands)
- if (isa<TensorType>(v.getType()))
- newResultTypes.push_back(v.getType());
- auto newOp = rewriter.create<GenericOp>(
- loc, newResultTypes, newInputOperands, newOutputOperands,
- rewriter.getAffineMapArrayAttr(newIndexingMaps),
- genericOp.getIteratorTypes(), genericOp.getDocAttr(),
- genericOp.getLibraryCallAttr(),
- [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
- return;
- });
- // Copy over unknown attributes. They might be load bearing for some flow.
- ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
- for (NamedAttribute kv : genericOp->getAttrs())
- if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
- newOp->setAttr(kv.getName(), kv.getValue());
-
- // Fix up the payload of the canonicalized operation.
- populateOpPayload(genericOp, newOp, origInsToNewInsPos,
- origOutsToNewOutsPos, rewriter);
-
- // Replace all live uses of the op.
- SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
- for (const auto &result : llvm::enumerate(genericOp.getResults())) {
- auto it = origOutsToNewOutsPos.find(result.index());
- if (it == origOutsToNewOutsPos.end())
+//===---------------------------------------------------------------------===//
+// Helper methods for operand deduplication and dead results elimination
+//===---------------------------------------------------------------------===//
+
+// Deduplicate input operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved input operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateInputOperands(
+ GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newInputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ OpOperand *inputOpOperand = en.value();
+ // Check if operand is dead and if dropping the indexing map makes the
+ // loops to shape computation invalid.
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+ // Add the current operands to the list of potentially droppable
+ // operands. If it cannot be dropped, this needs to be popped back.
+ droppedOpOperands.push_back(inputOpOperand);
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
continue;
- replacementsVals[result.index()] = newOp.getResult(it->second);
+ droppedOpOperands.pop_back();
}
- rewriter.replaceOp(genericOp, replacementsVals);
- return success();
- }
-private:
- /// If unset, outputs are not modified by this pattern.
- bool removeOutputs;
-
- // Deduplicate input operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved input operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateInputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newInputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
- OpOperand *inputOpOperand = en.value();
- // Check if operand is dead and if dropping the indexing map makes the
- // loops to shape computation invalid.
- if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
- // Add the current operands to the list of potentially droppable
- // operands. If it cannot be dropped, this needs to be popped back.
- droppedOpOperands.push_back(inputOpOperand);
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
- continue;
- droppedOpOperands.pop_back();
- }
+ // Check if this operand is a duplicate.
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
+ auto it =
+ dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));
+ if (it != dedupedInputs.end()) {
+ origToNewPos[en.index()] = it->second;
+ droppedOpOperands.push_back(inputOpOperand);
+ continue;
+ }
- // Check if this operand is a duplicate.
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
- auto it = dedupedInputs.find(
- std::make_pair(inputOpOperand->get(), indexingMap));
- if (it != dedupedInputs.end()) {
- origToNewPos[en.index()] = it->second;
- droppedOpOperands.push_back(inputOpOperand);
- continue;
- }
+ // This is a preserved argument.
+ origToNewPos[en.index()] = newInputOperands.size();
+ dedupedInputs[{inputOpOperand->get(), indexingMap}] =
+ newInputOperands.size();
+ newInputOperands.push_back(inputOpOperand->get());
+ newIndexingMaps.push_back(indexingMap);
+ }
+ return origToNewPos;
+}
- // This is a preserved argument.
- origToNewPos[en.index()] = newInputOperands.size();
- dedupedInputs[{inputOpOperand->get(), indexingMap}] =
- newInputOperands.size();
- newInputOperands.push_back(inputOpOperand->get());
- newIndexingMaps.push_back(indexingMap);
+// Deduplicate output operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved output operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateOutputOperands(
+ GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newOutputOperands,
+ SmallVector<AffineMap> &newIndexingMaps, bool removeOutputs) {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+ dedupedOutpts;
+ // If the op doesn't have tensor semantics or outputs should not be removed,
+ // keep all the outputs as preserved.
+ if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
+ origToNewPos[en.index()] = newOutputOperands.size();
+ newOutputOperands.push_back(en.value().get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
}
return origToNewPos;
}
-
- // Deduplicate output operands, and return the
- // - Mapping from operand position in the original op, to operand position in
- // the canonicalized op.
- // - The preserved output operands list (by reference).
- llvm::SmallDenseMap<unsigned, unsigned>
- deduplicateOutputOperands(GenericOp genericOp,
- SmallVector<OpOperand *> &droppedOpOperands,
- SmallVector<Value> &newOutputOperands,
- SmallVector<AffineMap> &newIndexingMaps) const {
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
- dedupedOutpts;
- // If the op doesn't have tensor semantics or outputs should not be removed,
- // keep all the outputs as preserved.
- if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
- for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
- origToNewPos[en.index()] = newOutputOperands.size();
- newOutputOperands.push_back(en.value().get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(&en.value()));
+ // Output argument can be dropped if the result has
+ // - no users, and
+ // - it is not used in the payload, and
+ // - the corresponding indexing maps are not needed for loop bound
+ // computation.
+ auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+ for (const auto &outputOpOperand :
+ llvm::enumerate(genericOp.getDpsInitsMutable())) {
+ OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
+ AffineMap indexingMap =
+ genericOp.getMatchingIndexingMap(&outputOpOperand.value());
+ auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
+ yieldOp->getOperand(outputOpOperand.index()));
+ if (isResultValueDead(genericOp, result)) {
+ // Check if the opoperand can be dropped without affecting loop
+ // bound computation. Add the operand to the list of dropped op
+ // operand for checking. If it cannot be dropped, need to pop the
+ // value back.
+ droppedOpOperands.push_back(&outputOpOperand.value());
+ if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+ continue;
}
- return origToNewPos;
+ droppedOpOperands.pop_back();
}
- // Output argument can be dropped if the result has
- // - no users, and
- // - it is not used in the payload, and
- // - the corresponding indexing maps are not needed for loop bound
- // computation.
- auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
- for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getDpsInitsMutable())) {
- OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
- AffineMap indexingMap =
- genericOp.getMatchingIndexingMap(&outputOpOperand.value());
- auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
- yieldOp->getOperand(outputOpOperand.index()));
- if (isResultValueDead(genericOp, result)) {
- // Check if the opoperand can be dropped without affecting loop
- // bound computation. Add the operand to the list of dropped op
- // operand for checking. If it cannot be dropped, need to pop the
- // value back.
+
+ if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
+ // The out operand can also be dropped if it is computed redundantly
+ // by another result, the conditions for that are
+ // - The same operand is used as the out operand
+ // - The same indexing map is used
+ // - The same yield value is used.
+ auto it = dedupedOutpts.find(key);
+ if (it != dedupedOutpts.end()) {
+ origToNewPos[outputOpOperand.index()] = it->second;
droppedOpOperands.push_back(&outputOpOperand.value());
- if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
- continue;
- }
- droppedOpOperands.pop_back();
+ continue;
}
+ }
- if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
- // The out operand can also be dropped if it is computed redundantly
- // by another result, the conditions for that are
- // - The same operand is used as the out operand
- // - The same indexing map is used
- // - The same yield value is used.
- auto it = dedupedOutpts.find(key);
- if (it != dedupedOutpts.end()) {
- origToNewPos[outputOpOperand.index()] = it->second;
- droppedOpOperands.push_back(&outputOpOperand.value());
- continue;
- }
- }
+ origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+ dedupedOutpts[key] = newOutputOperands.size();
+ newOutputOperands.push_back(outputOpOperand.value().get());
+ newIndexingMaps.push_back(
+ genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+ }
+ return origToNewPos;
+}
- origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
- dedupedOutpts[key] = newOutputOperands.size();
- newOutputOperands.push_back(outputOpOperand.value().get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+// Populate the body of the canonicalized operation.
+static void populateOpPayload(
+ GenericOp genericOp, GenericOp newOp,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+ RewriterBase &rewriter) {
+ // Merge the body of the original op with the new op.
+ Block *newOpBlock = &newOp.getRegion().front();
+ assert(newOpBlock->empty() && "expected new op to have an empty payload");
+ Block *origOpBlock = &genericOp.getRegion().front();
+ SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+
+ // Replace all arguments in the original op, with arguments from the
+ // canonicalized op.
+ auto updateReplacements =
+ [&](SmallVector<OpOperand *> &origOperands,
+ SmallVector<OpOperand *> &newOperands,
+ const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+ for (const auto &origOperand : llvm::enumerate(origOperands)) {
+ auto it = map.find(origOperand.index());
+ if (it == map.end())
+ continue;
+ OpOperand *newOperand = newOperands[it->second];
+ replacements[origOperand.value()->getOperandNumber()] =
+ newOpBlock->getArgument(newOperand->getOperandNumber());
+ }
+ };
+
+ SmallVector<OpOperand *> origInputOperands = genericOp.getDpsInputOperands();
+ SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
+ updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+ SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
+ genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
+ newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ updateReplacements(origOutputOperands, newOutputOperands,
+ origOutsToNewOutsPos);
+
+ // Drop the unused yield args.
+ if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
+ rewriter.setInsertionPoint(origYieldOp);
+
+ SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
+ for (const auto &yieldOpOperands :
+ llvm::enumerate(origYieldOp.getValues())) {
+ auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+ if (it == origOutsToNewOutsPos.end())
+ continue;
+ newYieldVals[it->second] = yieldOpOperands.value();
}
- return origToNewPos;
+ rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
}
- // Populate the body of the canonicalized operation.
- void populateOpPayload(
- GenericOp genericOp, GenericOp newOp,
- const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
- const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
- PatternRewriter &rewriter) const {
- // Merge the body of the original op with the new op.
- Block *newOpBlock = &newOp.getRegion().front();
- assert(newOpBlock->empty() && "expected new op to have an empty payload");
- Block *origOpBlock = &genericOp.getRegion().front();
- SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
- // Replace all arguments in the original op, with arguments from the
- // canonicalized op.
- auto updateReplacements =
- [&](SmallVector<OpOperand *> &origOperands,
- SmallVector<OpOperand *> &newOperands,
- const llvm::SmallDenseMap<unsigned, unsigned> &map) {
- for (const auto &origOperand : llvm::enumerate(origOperands)) {
- auto it = map.find(origOperand.index());
- if (it == map.end())
- continue;
- OpOperand *newOperand = newOperands[it->second];
- replacements[origOperand.value()->getOperandNumber()] =
- newOpBlock->getArgument(newOperand->getOperandNumber());
- }
- };
-
- SmallVector<OpOperand *> origInputOperands =
- genericOp.getDpsInputOperands();
- SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
- updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
- SmallVector<OpOperand *> origOutputOperands =
- llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
- [](OpOperand &o) { return &o; }));
- SmallVector<OpOperand *> newOutputOperands =
- llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
- [](OpOperand &o) { return &o; }));
- updateReplacements(origOutputOperands, newOutputOperands,
- origOutsToNewOutsPos);
-
- // Drop the unused yield args.
- if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
- OpBuilder::InsertionGuard g(rewriter);
- YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
- rewriter.setInsertionPoint(origYieldOp);
-
- SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
- for (const auto &yieldOpOperands :
- llvm::enumerate(origYieldOp.getValues())) {
- auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
- if (it == origOutsToNewOutsPos.end())
- continue;
- newYieldVals[it->second] = yieldOpOperands.value();
- }
- rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
- }
+ rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+}
----------------
MaheshRavishankar wrote:
Is there a need for that? It's already static....
https://github.com/llvm/llvm-project/pull/125141
More information about the Mlir-commits
mailing list