[Mlir-commits] [mlir] b62060a - [mlir][Linalg] NFC: Refactor canonicalization for deduping generic op operands.
Mahesh Ravishankar
llvmlistbot at llvm.org
Fri Jul 15 12:47:59 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-07-15T19:47:45Z
New Revision: b62060a8e330e9cc7b537b50d797344666ae51fc
URL: https://github.com/llvm/llvm-project/commit/b62060a8e330e9cc7b537b50d797344666ae51fc
DIFF: https://github.com/llvm/llvm-project/commit/b62060a8e330e9cc7b537b50d797344666ae51fc.diff
LOG: [mlir][Linalg] NFC: Refactor canonicalization for deduping generic op operands.
This is an NFC (no functional change) refactoring that makes it easier to
extend this canonicalization to more use cases. It also makes the code
easier to understand and adapt.
Differential Revision: https://reviews.llvm.org/D129829
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1ce7d4dab1f13..928a8039a43e8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -860,67 +860,128 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
PatternRewriter &rewriter) const override {
// Create a map from argument position in the original op to the argument
// position in the new op. If the argument is dropped it wont have an entry.
- llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
- unsigned numNewArgs = 0;
SmallVector<OpOperand *> droppedOpOperands;
- llvm::SmallDenseSet<unsigned> droppedOutputs;
// Information needed to build the new op.
SmallVector<Value> newInputOperands, newOutputOperands;
SmallVector<AffineMap> newIndexingMaps;
+
+ // Gather information about duplicate input operands.
+ llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
+ deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
+ newIndexingMaps);
+
+ // Gather information about the dropped outputs.
+ llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
+ deduplicateOutputOperands(genericOp, droppedOpOperands,
+ newOutputOperands, newIndexingMaps);
+
+ // Check if there is any change to operands.
+ if (newInputOperands.size() + newOutputOperands.size() ==
+ static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+ return failure();
+
+ // Create the new op with the body being empty.
+ Location loc = genericOp.getLoc();
SmallVector<Type> newResultTypes;
+ if (genericOp.hasTensorSemantics()) {
+ newResultTypes = llvm::to_vector(llvm::map_range(
+ newOutputOperands, [](Value v) { return v.getType(); }));
+ }
+ auto newOp = rewriter.create<GenericOp>(
+ loc, newResultTypes, newInputOperands, newOutputOperands,
+ rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ genericOp.iterator_types(), genericOp.docAttr(),
+ genericOp.library_callAttr(),
+ [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+ return;
+ });
+ // Copy over unknown attributes. They might be load bearing for some flow.
+ ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+ for (NamedAttribute kv : genericOp->getAttrs())
+ if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+ newOp->setAttr(kv.getName(), kv.getValue());
- // Input argument can be dropped if
- // - it has no uses, or,
- // - there is a duplicate operand which is accessed using the same
- // indexing map.
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- auto indexingMaps = genericOp.getIndexingMaps();
- ArrayRef<AffineMap> unprocessedIndexingMaps(indexingMaps);
- for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
- BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
- unsigned argNum = arg.getArgNumber();
- unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
+ // Fix up the payload of the canonicalized operation.
+ populateOpPayload(genericOp, newOp, origInsToNewInsPos,
+ origOutsToNewOutsPos, rewriter);
+ // Replace all live uses of the op.
+ SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+ for (auto result : llvm::enumerate(genericOp.getResults())) {
+ auto it = origOutsToNewOutsPos.find(result.index());
+ if (it == origOutsToNewOutsPos.end())
+ continue;
+ replacementsVals[result.index()] = newOp.getResult(it->second);
+ }
+ rewriter.replaceOp(genericOp, replacementsVals);
+ return success();
+ }
+
+private:
+ // Deduplicate input operands, and return the
+ // - Mapping from operand position in the original op, to operand position in
+ // the canonicalized op.
+ // - The preserved input operands list (by reference).
+ llvm::SmallDenseMap<unsigned, unsigned>
+ deduplicateInputOperands(GenericOp genericOp,
+ SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newInputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) const {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+ for (auto inputOpOperand : llvm::enumerate(genericOp.getInputOperands())) {
// Check if operand is dead and if dropping the indexing map makes the
// loops to shape computation invalid.
- if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
// Add the current operands to the list of potentially droppable
// operands. If it cannot be dropped, this needs to be popped back.
- droppedOpOperands.push_back(inputOpOperand);
+ droppedOpOperands.push_back(inputOpOperand.value());
if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
continue;
droppedOpOperands.pop_back();
}
// Check if this operand is a duplicate.
- AffineMap indexingMap = genericOp.getTiedIndexingMap(inputOpOperand);
+ AffineMap indexingMap =
+ genericOp.getTiedIndexingMap(inputOpOperand.value());
auto it = dedupedInputs.find(
- std::make_pair(inputOpOperand->get(), indexingMap));
+ std::make_pair(inputOpOperand.value()->get(), indexingMap));
if (it != dedupedInputs.end()) {
- origToNewPos[argNum] = it->second;
- droppedOpOperands.push_back(inputOpOperand);
+ origToNewPos[inputOpOperand.index()] = it->second;
+ droppedOpOperands.push_back(inputOpOperand.value());
continue;
}
// This is a preserved argument.
- origToNewPos[argNum] = numNewArgs;
- dedupedInputs[{inputOpOperand->get(), indexingMap}] = numNewArgs;
- newInputOperands.push_back(inputOpOperand->get());
+ origToNewPos[inputOpOperand.index()] = newInputOperands.size();
+ dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
+ newInputOperands.size();
+ newInputOperands.push_back(inputOpOperand.value()->get());
newIndexingMaps.push_back(indexingMap);
- numNewArgs++;
}
+ return origToNewPos;
+ }
+ // Deduplicate output operands, and return the
+ // - Mapping from operand position in the original op, to operand position in
+ // the canonicalized op.
+ // - The preserved output operands list (by reference).
+ llvm::SmallDenseMap<unsigned, unsigned>
+ deduplicateOutputOperands(GenericOp genericOp,
+ SmallVector<OpOperand *> &droppedOpOperands,
+ SmallVector<Value> &newOutputOperands,
+ SmallVector<AffineMap> &newIndexingMaps) const {
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
// If the op doesnt have tensor semantics, keep all the outputs as
// preserved.
if (!genericOp.hasTensorSemantics()) {
- for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
- unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
- BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
- origToNewPos[arg.getArgNumber()] = numNewArgs++;
- newOutputOperands.push_back(outputOpOperand->get());
+ for (auto outputOpOperand :
+ llvm::enumerate(genericOp.getOutputOperands())) {
+ origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+ newOutputOperands.push_back(outputOpOperand.value()->get());
newIndexingMaps.push_back(
- genericOp.getTiedIndexingMap(outputOpOperand));
+ genericOp.getTiedIndexingMap(outputOpOperand.value()));
}
} else {
// Output argument can be dropped if the result has
@@ -928,12 +989,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// - it is not used in the payload, and
// - the corresponding indexing maps are not needed for loop bound
// computation.
- for (const auto &outputOpOperand :
+ for (auto outputOpOperand :
llvm::enumerate(genericOp.getOutputOperands())) {
- unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
Value result = genericOp.getResult(outputOpOperand.index());
- BlockArgument arg =
- genericOp.getTiedBlockArgument(outputOpOperand.value());
if (result.use_empty() &&
!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
// Check if the opoperand can be dropped without affecting loop bound
@@ -941,77 +999,75 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// checking. If it cannot be dropped, need to pop the value back.
droppedOpOperands.push_back(outputOpOperand.value());
if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
- droppedOutputs.insert(outputOpOperand.index());
continue;
}
droppedOpOperands.pop_back();
}
- origToNewPos[arg.getArgNumber()] = numNewArgs++;
+ origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
newOutputOperands.push_back(outputOpOperand.value()->get());
newIndexingMaps.push_back(
genericOp.getTiedIndexingMap(outputOpOperand.value()));
- newResultTypes.push_back(result.getType());
}
}
- // Check if there is any change to operands.
- if (newInputOperands.size() + newOutputOperands.size() ==
- static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
- return failure();
-
- // Create the new op with the body being empty.
- Location loc = genericOp.getLoc();
- auto newOp = rewriter.create<GenericOp>(
- loc, newResultTypes, newInputOperands, newOutputOperands,
- rewriter.getAffineMapArrayAttr(newIndexingMaps),
- genericOp.iterator_types(), genericOp.docAttr(),
- genericOp.library_callAttr(),
- [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
- return;
- });
- // Copy over unknown attributes. They might be load bearing for some flow.
- ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
- for (NamedAttribute kv : genericOp->getAttrs())
- if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
- newOp->setAttr(kv.getName(), kv.getValue());
+ return origToNewPos;
+ }
+ // Populate the body of the canonicalized operation.
+ void populateOpPayload(
+ GenericOp genericOp, GenericOp newOp,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+ const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+ PatternRewriter &rewriter) const {
// Merge the body of the original op with the new op.
Block *newOpBlock = &newOp.region().front();
+ assert(newOpBlock->empty() && "expected new op to have an empty payload");
Block *origOpBlock = &genericOp.region().front();
SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
- for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
- auto it = origToNewPos.find(argNum);
- if (it != origToNewPos.end())
- replacements[argNum] = newOpBlock->getArgument(it->second);
- }
+
+ // Replace all arguments in the original op, with arguments from the
+ // canonicalized op.
+ auto updateReplacements =
+ [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
+ const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+ for (auto origOperand : llvm::enumerate(origOperands)) {
+ auto it = map.find(origOperand.index());
+ if (it == map.end())
+ continue;
+ OpOperand *newOperand = newOperands[it->second];
+ replacements[origOperand.value()->getOperandNumber()] =
+ newOpBlock->getArgument(newOperand->getOperandNumber());
+ }
+ };
+
+ OpOperandVector origInputOperands = genericOp.getInputOperands();
+ OpOperandVector newInputOperands = newOp.getInputOperands();
+ updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+ OpOperandVector origOutputOperands = genericOp.getOutputOperands();
+ OpOperandVector newOutputOperands = newOp.getOutputOperands();
+ updateReplacements(origOutputOperands, newOutputOperands,
+ origOutsToNewOutsPos);
+
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
// Drop the unused yield args.
- Block *block = &newOp.region().front();
- if (!droppedOutputs.empty()) {
+ if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
OpBuilder::InsertionGuard g(rewriter);
- SmallVector<Value> newYieldVals;
- YieldOp origYieldOp = cast<YieldOp>(block->getTerminator());
+ YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
rewriter.setInsertionPoint(origYieldOp);
+
+ SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
for (const auto &yieldOpOperands :
llvm::enumerate(origYieldOp.values())) {
- if (!droppedOutputs.count(yieldOpOperands.index())) {
- newYieldVals.push_back(yieldOpOperands.value());
+ auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+ if (it == origOutsToNewOutsPos.end())
continue;
- }
+ newYieldVals[it->second] = yieldOpOperands.value();
}
rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
}
-
- // Replace all live uses of the op.
- SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
- unsigned newResultNum = 0;
- for (const auto &result : llvm::enumerate(genericOp.getResults()))
- if (!droppedOutputs.count(result.index()))
- replacementsVals[result.index()] = newOp.getResult(newResultNum++);
- rewriter.replaceOp(genericOp, replacementsVals);
- return success();
}
};
More information about the Mlir-commits
mailing list