[Mlir-commits] [mlir] 4e02eb8 - [mlir] Optimize the implementation of RegionDCE

River Riddle llvmlistbot at llvm.org
Wed Mar 10 16:45:42 PST 2021


Author: River Riddle
Date: 2021-03-10T16:39:50-08:00
New Revision: 4e02eb8014c4dd8dd21071947525926bbe8046ef

URL: https://github.com/llvm/llvm-project/commit/4e02eb8014c4dd8dd21071947525926bbe8046ef
DIFF: https://github.com/llvm/llvm-project/commit/4e02eb8014c4dd8dd21071947525926bbe8046ef.diff

LOG: [mlir] Optimize the implementation of RegionDCE

The current implementation has some inefficiencies that become noticeable when running on large modules. This revision optimizes the code and updates some outdated idioms to use newer utilities. The main components of this optimization are:

* Add an overload of Block::eraseArguments that takes a predicate, allowing O(N) erasure of disjoint arguments.
* Don't process entry block arguments, since they are not erased at this point.
* Don't track the liveness of individual operation results, since they are never erased; track the parent operation instead.

Differential Revision: https://reviews.llvm.org/D98309

Added: 
    

Modified: 
    mlir/include/mlir/IR/Block.h
    mlir/lib/IR/Block.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9f26155b5265..f5436ef49d3c 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -108,7 +108,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   void eraseArguments(ArrayRef<unsigned> argIndices);
   /// Erases the arguments that have their corresponding bit set in
   /// `eraseIndices` and removes them from the argument list.
-  void eraseArguments(llvm::BitVector eraseIndices);
+  void eraseArguments(const llvm::BitVector &eraseIndices);
+  /// Erases arguments using the given predicate. If the predicate returns true,
+  /// that argument is erased.
+  void eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn);
 
   unsigned getNumArguments() { return arguments.size(); }
   BlockArgument getArgument(unsigned i) { return arguments[i]; }

diff  --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 53797aa49fd2..07e2e5c007fe 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -188,23 +188,32 @@ void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
   eraseArguments(eraseIndices);
 }
 
-void Block::eraseArguments(llvm::BitVector eraseIndices) {
-  // We do this in reverse so that we erase later indices before earlier
-  // indices, to avoid shifting the later indices.
-  unsigned originalNumArgs = getNumArguments();
-  int64_t firstErased = originalNumArgs;
-  for (unsigned i = 0; i < originalNumArgs; ++i) {
-    int64_t currentPos = originalNumArgs - i - 1;
-    if (eraseIndices.test(currentPos)) {
-      arguments[currentPos].destroy();
-      arguments.erase(arguments.begin() + currentPos);
-      firstErased = currentPos;
+void Block::eraseArguments(const llvm::BitVector &eraseIndices) {
+  eraseArguments(
+      [&](BlockArgument arg) { return eraseIndices.test(arg.getArgNumber()); });
+}
+
+void Block::eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn) {
+  auto firstDead = llvm::find_if(arguments, shouldEraseFn);
+  if (firstDead == arguments.end())
+    return;
+
+  // Destroy the first dead argument; this avoids reapplying the predicate to
+  // it.
+  unsigned index = firstDead->getArgNumber();
+  firstDead->destroy();
+
+  // Iterate the remaining arguments to remove any that are now dead.
+  for (auto it = std::next(firstDead), e = arguments.end(); it != e; ++it) {
+    // Destroy dead arguments, and shift those that are still live.
+    if (shouldEraseFn(*it)) {
+      it->destroy();
+    } else {
+      it->setArgNumber(index++);
+      *firstDead++ = *it;
     }
   }
-  // Update the cached position for the arguments after the first erased one.
-  int64_t index = firstErased;
-  for (BlockArgument arg : llvm::drop_begin(arguments, index))
-    arg.setArgNumber(index++);
+  arguments.erase(firstDead, arguments.end());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 125d32b782c9..7dd064ef0341 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -139,9 +139,23 @@ namespace {
 class LiveMap {
 public:
   /// Value methods.
-  bool wasProvenLive(Value value) { return liveValues.count(value); }
+  bool wasProvenLive(Value value) {
+    // TODO: For results that are removable, e.g. for region based control flow,
+    // we could allow for these values to be tracked independently.
+    if (OpResult result = value.dyn_cast<OpResult>())
+      return wasProvenLive(result.getOwner());
+    return wasProvenLive(value.cast<BlockArgument>());
+  }
+  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
   void setProvedLive(Value value) {
-    changed |= liveValues.insert(value).second;
+    // TODO: For results that are removable, e.g. for region based control flow,
+    // we could allow for these values to be tracked independently.
+    if (OpResult result = value.dyn_cast<OpResult>())
+      return setProvedLive(result.getOwner());
+    setProvedLive(value.cast<BlockArgument>());
+  }
+  void setProvedLive(BlockArgument arg) {
+    changed |= liveValues.insert(arg).second;
   }
 
   /// Operation methods.
@@ -192,15 +206,6 @@ static void processValue(Value value, LiveMap &liveMap) {
     liveMap.setProvedLive(value);
 }
 
-static bool isOpIntrinsicallyLive(Operation *op) {
-  // This pass doesn't modify the CFG, so terminators are never deleted.
-  if (op->mightHaveTrait<OpTrait::IsTerminator>())
-    return true;
-  // If the op has a side effect, we treat it as live.
-  // TODO: Properly handle region side effects.
-  return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
-}
-
 static void propagateLiveness(Region &region, LiveMap &liveMap);
 
 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
@@ -226,9 +231,6 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
 }
 
 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
-  // All Value's are either a block argument or an op result.
-  // We call processValue on those cases.
-
   // Recurse on any regions the op has.
   for (Region &region : op->getRegions())
     propagateLiveness(region, liveMap);
@@ -237,18 +239,17 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) {
   if (op->hasTrait<OpTrait::IsTerminator>())
     return propagateTerminatorLiveness(op, liveMap);
 
-  // Process the op itself.
-  if (isOpIntrinsicallyLive(op)) {
-    liveMap.setProvedLive(op);
+  // Don't reprocess live operations.
+  if (liveMap.wasProvenLive(op))
     return;
-  }
+
+  // Process the op itself.
+  if (!wouldOpBeTriviallyDead(op))
+    return liveMap.setProvedLive(op);
+
+  // If the op isn't intrinsically alive, check its results.
   for (Value value : op->getResults())
     processValue(value, liveMap);
-  bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
-    return liveMap.wasProvenLive(value);
-  });
-  if (provedLive)
-    liveMap.setProvedLive(op);
 }
 
 static void propagateLiveness(Region &region, LiveMap &liveMap) {
@@ -260,8 +261,18 @@ static void propagateLiveness(Region &region, LiveMap &liveMap) {
     // faster convergence to a fixed point (we try to visit uses before defs).
     for (Operation &op : llvm::reverse(block->getOperations()))
       propagateLiveness(&op, liveMap);
-    for (Value value : block->getArguments())
-      processValue(value, liveMap);
+
+    // We currently do not remove entry block arguments, so there is no need to
+    // track their liveness.
+    // TODO: We could track these and enable removing dead operands/arguments
+    // from region control flow operations.
+    if (block->isEntryBlock())
+      continue;
+
+    for (Value value : block->getArguments()) {
+      if (!liveMap.wasProvenLive(value))
+        processValue(value, liveMap);
+    }
   }
 }
 
@@ -314,11 +325,12 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
       eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
       for (Operation &childOp :
            llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
-        erasedAnything |=
-            succeeded(deleteDeadness(childOp.getRegions(), liveMap));
         if (!liveMap.wasProvenLive(&childOp)) {
           erasedAnything = true;
           childOp.erase();
+        } else {
+          erasedAnything |=
+              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
         }
       }
     }
@@ -326,13 +338,8 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
     // The entry block has an unknown contract with their enclosing block, so
     // skip it.
     for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
-      // Iterate in reverse to avoid shifting later arguments when deleting
-      // earlier arguments.
-      for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
-        if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
-          block.eraseArgument(e - i - 1);
-          erasedAnything = true;
-        }
+      block.eraseArguments(
+          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
     }
   }
   return success(erasedAnything);


        


More information about the Mlir-commits mailing list