[Mlir-commits] [mlir] 4e02eb8 - [mlir] Optimize the implementation of RegionDCE
River Riddle
llvmlistbot at llvm.org
Wed Mar 10 16:45:42 PST 2021
Author: River Riddle
Date: 2021-03-10T16:39:50-08:00
New Revision: 4e02eb8014c4dd8dd21071947525926bbe8046ef
URL: https://github.com/llvm/llvm-project/commit/4e02eb8014c4dd8dd21071947525926bbe8046ef
DIFF: https://github.com/llvm/llvm-project/commit/4e02eb8014c4dd8dd21071947525926bbe8046ef.diff
LOG: [mlir] Optimize the implementation of RegionDCE
The current implementation has some inefficiencies that become noticeable when running on large modules. This revision optimizes the code, and updates some outdated idioms with newer utilities. The main components of this optimization include:
* Add an overload of Block::eraseArguments that allows for O(N) erasure of disjoint arguments.
* Don't process entry block arguments given that we don't erase them at this point.
* Don't track individual operation results, given that we don't erase them. We can just track the parent operation.
Differential Revision: https://reviews.llvm.org/D98309
Added:
Modified:
mlir/include/mlir/IR/Block.h
mlir/lib/IR/Block.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9f26155b5265..f5436ef49d3c 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -108,7 +108,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erases the arguments that have their corresponding bit set in
/// `eraseIndices` and removes them from the argument list.
- void eraseArguments(llvm::BitVector eraseIndices);
+ void eraseArguments(const llvm::BitVector &eraseIndices);
+ /// Erases arguments using the given predicate. If the predicate returns true,
+ /// that argument is erased.
+ void eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn);
unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 53797aa49fd2..07e2e5c007fe 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -188,23 +188,32 @@ void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
eraseArguments(eraseIndices);
}
-void Block::eraseArguments(llvm::BitVector eraseIndices) {
- // We do this in reverse so that we erase later indices before earlier
- // indices, to avoid shifting the later indices.
- unsigned originalNumArgs = getNumArguments();
- int64_t firstErased = originalNumArgs;
- for (unsigned i = 0; i < originalNumArgs; ++i) {
- int64_t currentPos = originalNumArgs - i - 1;
- if (eraseIndices.test(currentPos)) {
- arguments[currentPos].destroy();
- arguments.erase(arguments.begin() + currentPos);
- firstErased = currentPos;
+void Block::eraseArguments(const llvm::BitVector &eraseIndices) {
+ eraseArguments(
+ [&](BlockArgument arg) { return eraseIndices.test(arg.getArgNumber()); });
+}
+
+void Block::eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn) {
+ auto firstDead = llvm::find_if(arguments, shouldEraseFn);
+ if (firstDead == arguments.end())
+ return;
+
+ // Destroy the first dead argument, this avoids reapplying the predicate to
+ // it.
+ unsigned index = firstDead->getArgNumber();
+ firstDead->destroy();
+
+ // Iterate the remaining arguments to remove any that are now dead.
+ for (auto it = std::next(firstDead), e = arguments.end(); it != e; ++it) {
+ // Destroy dead arguments, and shift those that are still live.
+ if (shouldEraseFn(*it)) {
+ it->destroy();
+ } else {
+ it->setArgNumber(index++);
+ *firstDead++ = *it;
}
}
- // Update the cached position for the arguments after the first erased one.
- int64_t index = firstErased;
- for (BlockArgument arg : llvm::drop_begin(arguments, index))
- arg.setArgNumber(index++);
+ arguments.erase(firstDead, arguments.end());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 125d32b782c9..7dd064ef0341 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -139,9 +139,23 @@ namespace {
class LiveMap {
public:
/// Value methods.
- bool wasProvenLive(Value value) { return liveValues.count(value); }
+ bool wasProvenLive(Value value) {
+ // TODO: For results that are removable, e.g. for region based control flow,
+ // we could allow for these values to be tracked independently.
+ if (OpResult result = value.dyn_cast<OpResult>())
+ return wasProvenLive(result.getOwner());
+ return wasProvenLive(value.cast<BlockArgument>());
+ }
+ bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
void setProvedLive(Value value) {
- changed |= liveValues.insert(value).second;
+ // TODO: For results that are removable, e.g. for region based control flow,
+ // we could allow for these values to be tracked independently.
+ if (OpResult result = value.dyn_cast<OpResult>())
+ return setProvedLive(result.getOwner());
+ setProvedLive(value.cast<BlockArgument>());
+ }
+ void setProvedLive(BlockArgument arg) {
+ changed |= liveValues.insert(arg).second;
}
/// Operation methods.
@@ -192,15 +206,6 @@ static void processValue(Value value, LiveMap &liveMap) {
liveMap.setProvedLive(value);
}
-static bool isOpIntrinsicallyLive(Operation *op) {
- // This pass doesn't modify the CFG, so terminators are never deleted.
- if (op->mightHaveTrait<OpTrait::IsTerminator>())
- return true;
- // If the op has a side effect, we treat it as live.
- // TODO: Properly handle region side effects.
- return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
-}
-
static void propagateLiveness(Region &region, LiveMap &liveMap);
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
@@ -226,9 +231,6 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
}
static void propagateLiveness(Operation *op, LiveMap &liveMap) {
- // All Value's are either a block argument or an op result.
- // We call processValue on those cases.
-
// Recurse on any regions the op has.
for (Region &region : op->getRegions())
propagateLiveness(region, liveMap);
@@ -237,18 +239,17 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) {
if (op->hasTrait<OpTrait::IsTerminator>())
return propagateTerminatorLiveness(op, liveMap);
- // Process the op itself.
- if (isOpIntrinsicallyLive(op)) {
- liveMap.setProvedLive(op);
+ // Don't reprocess live operations.
+ if (liveMap.wasProvenLive(op))
return;
- }
+
+ // Process the op itself.
+ if (!wouldOpBeTriviallyDead(op))
+ return liveMap.setProvedLive(op);
+
+ // If the op isn't intrinsically alive, check it's results.
for (Value value : op->getResults())
processValue(value, liveMap);
- bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
- return liveMap.wasProvenLive(value);
- });
- if (provedLive)
- liveMap.setProvedLive(op);
}
static void propagateLiveness(Region &region, LiveMap &liveMap) {
@@ -260,8 +261,18 @@ static void propagateLiveness(Region &region, LiveMap &liveMap) {
// faster convergence to a fixed point (we try to visit uses before defs).
for (Operation &op : llvm::reverse(block->getOperations()))
propagateLiveness(&op, liveMap);
- for (Value value : block->getArguments())
- processValue(value, liveMap);
+
+ // We currently do not remove entry block arguments, so there is no need to
+ // track their liveness.
+ // TODO: We could track these and enable removing dead operands/arguments
+ // from region control flow operations.
+ if (block->isEntryBlock())
+ continue;
+
+ for (Value value : block->getArguments()) {
+ if (!liveMap.wasProvenLive(value))
+ processValue(value, liveMap);
+ }
}
}
@@ -314,11 +325,12 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
for (Operation &childOp :
llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
- erasedAnything |=
- succeeded(deleteDeadness(childOp.getRegions(), liveMap));
if (!liveMap.wasProvenLive(&childOp)) {
erasedAnything = true;
childOp.erase();
+ } else {
+ erasedAnything |=
+ succeeded(deleteDeadness(childOp.getRegions(), liveMap));
}
}
}
@@ -326,13 +338,8 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
// The entry block has an unknown contract with their enclosing block, so
// skip it.
for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
- // Iterate in reverse to avoid shifting later arguments when deleting
- // earlier arguments.
- for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
- if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
- block.eraseArgument(e - i - 1);
- erasedAnything = true;
- }
+ block.eraseArguments(
+ [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
}
}
return success(erasedAnything);
More information about the Mlir-commits mailing list