[llvm] 05ae04c - [DA][SDA] SyncDependenceAnalysis re-write

Simon Moll via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 30 08:37:11 PDT 2020


Author: Simon Moll
Date: 2020-09-30T17:36:26+02:00
New Revision: 05ae04c396519cca9ef50d3b9cafb0cd9c87d1d7

URL: https://github.com/llvm/llvm-project/commit/05ae04c396519cca9ef50d3b9cafb0cd9c87d1d7
DIFF: https://github.com/llvm/llvm-project/commit/05ae04c396519cca9ef50d3b9cafb0cd9c87d1d7.diff

LOG: [DA][SDA] SyncDependenceAnalysis re-write

This patch achieves two things:
1. It breaks up the `join_blocks` interface between the SDA to the DA to
   return two separate sets for divergent loops exits and divergent,
disjoint path joins.
2. It updates the SDA algorithm to run in O(n) time and improves the
   precision on divergent loop exits.

This fixes `https://bugs.llvm.org/show_bug.cgi?id=46372` (by virtue of
the improved `join_blocks` interface) and revealed an imprecise expected
result in the `Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll`
test.

Reviewed By: sameerds

Differential Revision: https://reviews.llvm.org/D84413

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/DivergenceAnalysis.h
    llvm/include/llvm/Analysis/SyncDependenceAnalysis.h
    llvm/lib/Analysis/DivergenceAnalysis.cpp
    llvm/lib/Analysis/SyncDependenceAnalysis.cpp
    llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll
    llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/DivergenceAnalysis.h b/llvm/include/llvm/Analysis/DivergenceAnalysis.h
index a2da97bb9059..8a32bfbcc758 100644
--- a/llvm/include/llvm/Analysis/DivergenceAnalysis.h
+++ b/llvm/include/llvm/Analysis/DivergenceAnalysis.h
@@ -59,8 +59,10 @@ class DivergenceAnalysis {
   /// \brief Mark \p UniVal as a value that is always uniform.
   void addUniformOverride(const Value &UniVal);
 
-  /// \brief Mark \p DivVal as a value that is always divergent.
-  void markDivergent(const Value &DivVal);
+  /// \brief Mark \p DivVal as a value that is always divergent. Will not do so
+  /// if `isAlwaysUniform(DivVal)`.
+  /// \returns Whether the tracked divergence state of \p DivVal changed.
+  bool markDivergent(const Value &DivVal);
 
   /// \brief Propagate divergence to all instructions in the region.
   /// Divergence is seeded by calls to \p markDivergent.
@@ -76,45 +78,38 @@ class DivergenceAnalysis {
   /// \brief Whether \p Val is divergent at its definition.
   bool isDivergent(const Value &Val) const;
 
-  /// \brief Whether \p U is divergent. Uses of a uniform value can be divergent.
+  /// \brief Whether \p U is divergent. Uses of a uniform value can be
+  /// divergent.
   bool isDivergentUse(const Use &U) const;
 
   void print(raw_ostream &OS, const Module *) const;
 
 private:
-  bool updateTerminator(const Instruction &Term) const;
-  bool updatePHINode(const PHINode &Phi) const;
-
-  /// \brief Computes whether \p Inst is divergent based on the
-  /// divergence of its operands.
-  ///
-  /// \returns Whether \p Inst is divergent.
-  ///
-  /// This should only be called for non-phi, non-terminator instructions.
-  bool updateNormalInstruction(const Instruction &Inst) const;
-
-  /// \brief Mark users of live-out users as divergent.
-  ///
-  /// \param LoopHeader the header of the divergent loop.
-  ///
-  /// Marks all users of live-out values of the loop headed by \p LoopHeader
-  /// as divergent and puts them on the worklist.
-  void taintLoopLiveOuts(const BasicBlock &LoopHeader);
-
-  /// \brief Push all users of \p Val (in the region) to the worklist
+  /// \brief Mark \p Term as divergent and push all Instructions that become
+  /// divergent as a result on the worklist.
+  void analyzeControlDivergence(const Instruction &Term);
+  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
+  /// the worklist.
+  void taintAndPushPhiNodes(const BasicBlock &JoinBlock);
+
+  /// \brief Identify all Instructions that become divergent because \p DivExit
+  /// is a divergent loop exit of \p DivLoop. Mark those instructions as
+  /// divergent and push them on the worklist.
+  void propagateLoopExitDivergence(const BasicBlock &DivExit,
+                                   const Loop &DivLoop);
+
+  /// \brief Internal implementation function for propagateLoopExitDivergence.
+  void analyzeLoopExitDivergence(const BasicBlock &DivExit,
+                                 const Loop &OuterDivLoop);
+
+  /// \brief Mark all instruction as divergent that use a value defined in \p
+  /// OuterDivLoop. Push their users on the worklist.
+  void analyzeTemporalDivergence(const Instruction &I,
+                                 const Loop &OuterDivLoop);
+
+  /// \brief Push all users of \p Val (in the region) to the worklist.
   void pushUsers(const Value &I);
 
-  /// \brief Push all phi nodes in @block to the worklist
-  void pushPHINodes(const BasicBlock &Block);
-
-  /// \brief Mark \p Block as join divergent
-  ///
-  /// A block is join divergent if two threads may reach it from different
-  /// incoming blocks at the same time.
-  void markBlockJoinDivergent(const BasicBlock &Block) {
-    DivergentJoinBlocks.insert(&Block);
-  }
-
   /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
   bool isTemporalDivergent(const BasicBlock &ObservingBlock,
                            const Value &Val) const;
@@ -126,24 +121,6 @@ class DivergenceAnalysis {
     return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
   }
 
-  /// \brief Propagate control-induced divergence to users (phi nodes and
-  /// instructions).
-  //
-  // \param JoinBlock is a divergent loop exit or join point of two disjoint
-  // paths.
-  // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
-  bool propagateJoinDivergence(const BasicBlock &JoinBlock,
-                               const Loop *TermLoop);
-
-  /// \brief Propagate induced value divergence due to control divergence in \p
-  /// Term.
-  void propagateBranchDivergence(const Instruction &Term);
-
-  /// \brief Propagate divergent caused by a divergent loop exit.
-  ///
-  /// \param ExitingLoop is a divergent loop.
-  void propagateLoopDivergence(const Loop &ExitingLoop);
-
 private:
   const Function &F;
   // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
@@ -166,7 +143,7 @@ class DivergenceAnalysis {
   DenseSet<const Value *> UniformOverrides;
 
   // Blocks with joining divergent control from different predecessors.
-  DenseSet<const BasicBlock *> DivergentJoinBlocks;
+  DenseSet<const BasicBlock *> DivergentJoinBlocks; // FIXME Deprecated
 
   // Detected/marked divergent values.
   DenseSet<const Value *> DivergentValues;

diff --git a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h
index 2f07b3135308..9838d629e93e 100644
--- a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h
+++ b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include <memory>
+#include <unordered_map>
 
 namespace llvm {
 
@@ -30,6 +31,26 @@ class Loop;
 class PostDominatorTree;
 
 using ConstBlockSet = SmallPtrSet<const BasicBlock *, 4>;
+struct ControlDivergenceDesc {
+  // Join points of divergent disjoint paths.
+  ConstBlockSet JoinDivBlocks;
+  // Divergent loop exits
+  ConstBlockSet LoopDivBlocks;
+};
+
+struct ModifiedPO {
+  std::vector<const BasicBlock *> LoopPO;
+  std::unordered_map<const BasicBlock *, unsigned> POIndex;
+  void appendBlock(const BasicBlock &BB) {
+    POIndex[&BB] = LoopPO.size();
+    LoopPO.push_back(&BB);
+  }
+  unsigned getIndexOf(const BasicBlock &BB) const {
+    return POIndex.find(&BB)->second;
+  }
+  unsigned size() const { return LoopPO.size(); }
+  const BasicBlock *getBlockAt(unsigned Idx) const { return LoopPO[Idx]; }
+};
 
 /// \brief Relates points of divergent control to join points in
 /// reducible CFGs.
@@ -51,28 +72,19 @@ class SyncDependenceAnalysis {
   /// header. Those exit blocks are added to the returned set.
   /// If L is the parent loop of \p Term and an exit of L is in the returned
   /// set then L is a divergent loop.
-  const ConstBlockSet &join_blocks(const Instruction &Term);
-
-  /// \brief Computes divergent join points and loop exits (in the surrounding
-  /// loop) caused by the divergent loop exits of\p Loop.
-  ///
-  /// The set of blocks which are reachable by disjoint paths from the
-  /// loop exits of \p Loop.
-  /// This treats the loop as a single node in \p Loop's parent loop.
-  /// The returned set has the same properties as for join_blocks(TermInst&).
-  const ConstBlockSet &join_blocks(const Loop &Loop);
+  const ControlDivergenceDesc &getJoinBlocks(const Instruction &Term);
 
 private:
-  static ConstBlockSet EmptyBlockSet;
+  static ControlDivergenceDesc EmptyDivergenceDesc;
+
+  ModifiedPO LoopPO;
 
-  ReversePostOrderTraversal<const Function *> FuncRPOT;
   const DominatorTree &DT;
   const PostDominatorTree &PDT;
   const LoopInfo &LI;
 
-  std::map<const Loop *, std::unique_ptr<ConstBlockSet>> CachedLoopExitJoins;
-  std::map<const Instruction *, std::unique_ptr<ConstBlockSet>>
-      CachedBranchJoins;
+  std::map<const Instruction *, std::unique_ptr<ControlDivergenceDesc>>
+      CachedControlDivDescs;
 };
 
 } // namespace llvm

diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp
index 343406c9bba1..d01a0b95612c 100644
--- a/llvm/lib/Analysis/DivergenceAnalysis.cpp
+++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp
@@ -1,4 +1,4 @@
-//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
+//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -97,42 +97,18 @@ DivergenceAnalysis::DivergenceAnalysis(
     : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
       IsLCSSAForm(IsLCSSAForm) {}
 
-void DivergenceAnalysis::markDivergent(const Value &DivVal) {
+bool DivergenceAnalysis::markDivergent(const Value &DivVal) {
+  if (isAlwaysUniform(DivVal))
+    return false;
   assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
   assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");
-  DivergentValues.insert(&DivVal);
+  return DivergentValues.insert(&DivVal).second;
 }
 
 void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
   UniformOverrides.insert(&UniVal);
 }
 
-bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const {
-  if (Term.getNumSuccessors() <= 1)
-    return false;
-  if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) {
-    assert(BranchTerm->isConditional());
-    return isDivergent(*BranchTerm->getCondition());
-  }
-  if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) {
-    return isDivergent(*SwitchTerm->getCondition());
-  }
-  if (isa<InvokeInst>(Term)) {
-    return false; // ignore abnormal executions through landingpad
-  }
-
-  llvm_unreachable("unexpected terminator");
-}
-
-bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const {
-  // TODO function calls with side effects, etc
-  for (const auto &Op : I.operands()) {
-    if (isDivergent(*Op))
-      return true;
-  }
-  return false;
-}
-
 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
                                              const Value &Val) const {
   const auto *Inst = dyn_cast<const Instruction>(&Val);
@@ -150,32 +126,6 @@ bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
   return false;
 }
 
-bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const {
-  // joining divergent disjoint path in Phi parent block
-  if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) {
-    return true;
-  }
-
-  // An incoming value could be divergent by itself.
-  // Otherwise, an incoming value could be uniform within the loop
-  // that carries its definition but it may appear divergent
-  // from outside the loop. This happens when divergent loop exits
-  // drop definitions of that uniform value in different iterations.
-  //
-  // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop
-  //   if (i % thread_id == 0) break;    // divergent loop exit
-  // }
-  // int divI = i;                 // divI is divergent
-  for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
-    const auto *InVal = Phi.getIncomingValue(i);
-    if (isDivergent(*Phi.getIncomingValue(i)) ||
-        isTemporalDivergent(*Phi.getParent(), *InVal)) {
-      return true;
-    }
-  }
-  return false;
-}
-
 bool DivergenceAnalysis::inRegion(const Instruction &I) const {
   return I.getParent() && inRegion(*I.getParent());
 }
@@ -184,35 +134,82 @@ bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
   return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
 }
 
-static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) {
-  for (auto &Op : I.operands()) {
-    auto *OpInst = dyn_cast<Instruction>(&Op);
+void DivergenceAnalysis::pushUsers(const Value &V) {
+  const auto *I = dyn_cast<const Instruction>(&V);
+
+  if (I && I->isTerminator()) {
+    analyzeControlDivergence(*I);
+    return;
+  }
+
+  for (const auto *User : V.users()) {
+    const auto *UserInst = dyn_cast<const Instruction>(User);
+    if (!UserInst)
+      continue;
+
+    // only compute divergent inside loop
+    if (!inRegion(*UserInst))
+      continue;
+
+    // All users of divergent values are immediate divergent
+    if (markDivergent(*UserInst))
+      Worklist.push_back(UserInst);
+  }
+}
+
+static const Instruction *getIfCarriedInstruction(const Use &U,
+                                                  const Loop &DivLoop) {
+  const auto *I = dyn_cast<const Instruction>(&U);
+  if (!I)
+    return nullptr;
+  if (!DivLoop.contains(I))
+    return nullptr;
+  return I;
+}
+
+void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I,
+                                                   const Loop &OuterDivLoop) {
+  if (isAlwaysUniform(I))
+    return;
+  if (isDivergent(I))
+    return;
+
+  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
+  assert((isa<PHINode>(I) || !IsLCSSAForm) &&
+         "In LCSSA form all users of loop-exiting defs are Phi nodes.");
+  for (const Use &Op : I.operands()) {
+    const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
     if (!OpInst)
       continue;
-    if (DivLoop->contains(OpInst->getParent()))
-      return true;
+    if (markDivergent(I))
+      pushUsers(I);
+    return;
   }
-  return false;
 }
 
 // marks all users of loop-carried values of the loop headed by LoopHeader as
 // divergent
-void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
-  auto *DivLoop = LI.getLoopFor(&LoopHeader);
-  assert(DivLoop && "loopHeader is not actually part of a loop");
+void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit,
+                                                   const Loop &OuterDivLoop) {
+  // All users are in immediate exit blocks
+  if (IsLCSSAForm) {
+    for (const auto &Phi : DivExit.phis()) {
+      analyzeTemporalDivergence(Phi, OuterDivLoop);
+    }
+    return;
+  }
 
-  SmallVector<BasicBlock *, 8> TaintStack;
-  DivLoop->getExitBlocks(TaintStack);
+  // For non-LCSSA we have to follow all live out edges wherever they may lead.
+  const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
+  SmallVector<const BasicBlock *, 8> TaintStack;
+  TaintStack.push_back(&DivExit);
 
   // Otherwise potential users of loop-carried values could be anywhere in the
   // dominance region of DivLoop (including its fringes for phi nodes)
   DenseSet<const BasicBlock *> Visited;
-  for (auto *Block : TaintStack) {
-    Visited.insert(Block);
-  }
-  Visited.insert(&LoopHeader);
+  Visited.insert(&DivExit);
 
-  while (!TaintStack.empty()) {
+  do {
     auto *UserBlock = TaintStack.back();
     TaintStack.pop_back();
 
@@ -220,33 +217,21 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
     if (!inRegion(*UserBlock))
       continue;
 
-    assert(!DivLoop->contains(UserBlock) &&
+    assert(!OuterDivLoop.contains(UserBlock) &&
            "irreducible control flow detected");
 
     // phi nodes at the fringes of the dominance region
     if (!DT.dominates(&LoopHeader, UserBlock)) {
       // all PHI nodes of UserBlock become divergent
       for (auto &Phi : UserBlock->phis()) {
-        Worklist.push_back(&Phi);
+        analyzeTemporalDivergence(Phi, OuterDivLoop);
       }
       continue;
     }
 
-    // taint outside users of values carried by DivLoop
+    // Taint outside users of values carried by OuterDivLoop.
     for (auto &I : *UserBlock) {
-      if (isAlwaysUniform(I))
-        continue;
-      if (isDivergent(I))
-        continue;
-      if (!usesLiveOut(I, DivLoop))
-        continue;
-
-      markDivergent(I);
-      if (I.isTerminator()) {
-        propagateBranchDivergence(I);
-      } else {
-        pushUsers(I);
-      }
+      analyzeTemporalDivergence(I, OuterDivLoop);
     }
 
     // visit all blocks in the dominance region
@@ -256,56 +241,57 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
       }
       TaintStack.push_back(SuccBlock);
     }
-  }
+  } while (!TaintStack.empty());
 }
 
-void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
-  for (const auto &Phi : Block.phis()) {
-    if (isDivergent(Phi))
-      continue;
-    Worklist.push_back(&Phi);
+void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit,
+                                                     const Loop &InnerDivLoop) {
+  LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");
+
+  // Find outer-most loop that does not contain \p DivExit
+  const Loop *DivLoop = &InnerDivLoop;
+  const Loop *OuterDivLoop = DivLoop;
+  const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
+  const unsigned LoopExitDepth =
+      ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
+  while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
+    DivergentLoops.insert(DivLoop); // all crossed loops are divergent
+    OuterDivLoop = DivLoop;
+    DivLoop = DivLoop->getParentLoop();
   }
-}
-
-void DivergenceAnalysis::pushUsers(const Value &V) {
-  for (const auto *User : V.users()) {
-    const auto *UserInst = dyn_cast<const Instruction>(User);
-    if (!UserInst)
-      continue;
-
-    if (isDivergent(*UserInst))
-      continue;
+  LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
+                    << "\n");
 
-    // only compute divergent inside loop
-    if (!inRegion(*UserInst))
-      continue;
-    Worklist.push_back(UserInst);
-  }
+  analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
 }
 
-bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
-                                                 const Loop *BranchLoop) {
-  LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");
+// this is a divergent join point - mark all phi nodes as divergent and push
+// them onto the stack.
+void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
+  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
+                    << "\n");
 
   // ignore divergence outside the region
   if (!inRegion(JoinBlock)) {
-    return false;
+    return;
   }
 
   // push non-divergent phi nodes in JoinBlock to the worklist
-  pushPHINodes(JoinBlock);
-
-  // disjoint-paths divergent at JoinBlock
-  markBlockJoinDivergent(JoinBlock);
-
-  // JoinBlock is a divergent loop exit
-  return BranchLoop && !BranchLoop->contains(&JoinBlock);
+  for (const auto &Phi : JoinBlock.phis()) {
+    if (isDivergent(Phi))
+      continue;
+    // FIXME Theoretically ,the 'undef' value could be replaced by any other
+    // value causing spurious divergence.
+    if (Phi.hasConstantOrUndefValue())
+      continue;
+    if (markDivergent(Phi))
+      Worklist.push_back(&Phi);
+  }
 }
 
-void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
-  LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");
-
-  markDivergent(Term);
+void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) {
+  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
+                    << "\n");
 
   // Don't propagate divergence from unreachable blocks.
   if (!DT.isReachableFromEntry(Term.getParent()))
@@ -313,104 +299,36 @@ void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
 
   const auto *BranchLoop = LI.getLoopFor(Term.getParent());
 
-  // whether there is a divergent loop exit from BranchLoop (if any)
-  bool IsBranchLoopDivergent = false;
+  const auto &DivDesc = SDA.getJoinBlocks(Term);
 
-  // iterate over all blocks reachable by disjoint from Term within the loop
-  // also iterates over loop exits that become divergent due to Term.
-  for (const auto *JoinBlock : SDA.join_blocks(Term)) {
-    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
+  // Iterate over all blocks now reachable by a disjoint path join
+  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
+    taintAndPushPhiNodes(*JoinBlock);
   }
 
-  // Branch loop is a divergent loop due to the divergent branch in Term
-  if (IsBranchLoopDivergent) {
-    assert(BranchLoop);
-    if (!DivergentLoops.insert(BranchLoop).second) {
-      return;
-    }
-    propagateLoopDivergence(*BranchLoop);
-  }
-}
-
-void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
-  LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");
-
-  // don't propagate beyond region
-  if (!inRegion(*ExitingLoop.getHeader()))
-    return;
-
-  const auto *BranchLoop = ExitingLoop.getParentLoop();
-
-  // Uses of loop-carried values could occur anywhere
-  // within the dominance region of the definition. All loop-carried
-  // definitions are dominated by the loop header (reducible control).
-  // Thus all users have to be in the dominance region of the loop header,
-  // except PHI nodes that can also live at the fringes of the dom region
-  // (incoming defining value).
-  if (!IsLCSSAForm)
-    taintLoopLiveOuts(*ExitingLoop.getHeader());
-
-  // whether there is a divergent loop exit from BranchLoop (if any)
-  bool IsBranchLoopDivergent = false;
-
-  // iterate over all blocks reachable by disjoint paths from exits of
-  // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn
-  // become divergent.
-  for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
-    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
-  }
-
-  // Branch loop is a divergent due to divergent loop exit in ExitingLoop
-  if (IsBranchLoopDivergent) {
-    assert(BranchLoop);
-    if (!DivergentLoops.insert(BranchLoop).second) {
-      return;
-    }
-    propagateLoopDivergence(*BranchLoop);
+  assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
+  for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
+    propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
   }
 }
 
 void DivergenceAnalysis::compute() {
-  for (auto *DivVal : DivergentValues) {
+  // Initialize worklist.
+  auto DivValuesCopy = DivergentValues;
+  for (const auto *DivVal : DivValuesCopy) {
+    assert(isDivergent(*DivVal) && "Worklist invariant violated!");
     pushUsers(*DivVal);
   }
 
-  // propagate divergence
+  // All values on the Worklist are divergent.
+  // Their users may not have been updated yed.
   while (!Worklist.empty()) {
     const Instruction &I = *Worklist.back();
     Worklist.pop_back();
 
-    // maintain uniformity of overrides
-    if (isAlwaysUniform(I))
-      continue;
-
-    bool WasDivergent = isDivergent(I);
-    if (WasDivergent)
-      continue;
-
-    // propagate divergence caused by terminator
-    if (I.isTerminator()) {
-      if (updateTerminator(I)) {
-        // propagate control divergence to affected instructions
-        propagateBranchDivergence(I);
-        continue;
-      }
-    }
-
-    // update divergence of I due to divergent operands
-    bool DivergentUpd = false;
-    const auto *Phi = dyn_cast<const PHINode>(&I);
-    if (Phi) {
-      DivergentUpd = updatePHINode(*Phi);
-    } else {
-      DivergentUpd = updateNormalInstruction(I);
-    }
-
     // propagate value divergence to users
-    if (DivergentUpd) {
-      markDivergent(I);
-      pushUsers(I);
-    }
+    assert(isDivergent(I) && "Worklist invariant violated!");
+    pushUsers(I);
   }
 }
 
@@ -444,7 +362,7 @@ GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
                                              const PostDominatorTree &PDT,
                                              const LoopInfo &LI,
                                              const TargetTransformInfo &TTI)
-    : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
+    : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) {
   for (auto &I : instructions(F)) {
     if (TTI.isSourceOfDivergence(&I)) {
       DA.markDivergent(I);

diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp
index 36bef705d4f3..0771bb52c4f4 100644
--- a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp
+++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp
@@ -1,4 +1,4 @@
-//==- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -==//
+//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -107,271 +107,353 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 
+#include <functional>
 #include <stack>
 #include <unordered_set>
 
 #define DEBUG_TYPE "sync-dependence"
 
+// The SDA algorithm operates on a modified CFG - we modify the edges leaving
+// loop headers as follows:
+//
+// * We remove all edges leaving all loop headers.
+// * We add additional edges from the loop headers to their exit blocks.
+//
+// The modification is virtual, that is whenever we visit a loop header we
+// pretend it had different successors.
+namespace {
+using namespace llvm;
+
+// Custom Post-Order Traveral
+//
+// We cannot use the vanilla (R)PO computation of LLVM because:
+// * We (virtually) modify the CFG.
+// * We want a loop-compact block enumeration, that is the numbers assigned by
+//   the traveral to the blocks of a loop are an interval.
+using POCB = std::function<void(const BasicBlock &)>;
+using VisitedSet = std::set<const BasicBlock *>;
+using BlockStack = std::vector<const BasicBlock *>;
+
+// forward
+static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
+                          VisitedSet &Finalized);
+
+// for a nested region (top-level loop or nested loop)
+static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
+                           POCB CallBack, VisitedSet &Finalized) {
+  const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
+  while (!Stack.empty()) {
+    const auto *NextBB = Stack.back();
+
+    auto *NestedLoop = LI.getLoopFor(NextBB);
+    bool IsNestedLoop = NestedLoop != Loop;
+
+    // Treat the loop as a node
+    if (IsNestedLoop) {
+      SmallVector<BasicBlock *, 3> NestedExits;
+      NestedLoop->getUniqueExitBlocks(NestedExits);
+      bool PushedNodes = false;
+      for (const auto *NestedExitBB : NestedExits) {
+        if (NestedExitBB == LoopHeader)
+          continue;
+        if (Loop && !Loop->contains(NestedExitBB))
+          continue;
+        if (Finalized.count(NestedExitBB))
+          continue;
+        PushedNodes = true;
+        Stack.push_back(NestedExitBB);
+      }
+      if (!PushedNodes) {
+        // All loop exits finalized -> finish this node
+        Stack.pop_back();
+        computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
+      }
+      continue;
+    }
+
+    // DAG-style
+    bool PushedNodes = false;
+    for (const auto *SuccBB : successors(NextBB)) {
+      if (SuccBB == LoopHeader)
+        continue;
+      if (Loop && !Loop->contains(SuccBB))
+        continue;
+      if (Finalized.count(SuccBB))
+        continue;
+      PushedNodes = true;
+      Stack.push_back(SuccBB);
+    }
+    if (!PushedNodes) {
+      // Never push nodes twice
+      Stack.pop_back();
+      if (!Finalized.insert(NextBB).second)
+        continue;
+      CallBack(*NextBB);
+    }
+  }
+}
+
+static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
+  VisitedSet Finalized;
+  BlockStack Stack;
+  Stack.reserve(24); // FIXME made-up number
+  Stack.push_back(&F.getEntryBlock());
+  computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
+}
+
+static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
+                          VisitedSet &Finalized) {
+  /// Call CallBack on all loop blocks.
+  std::vector<const BasicBlock *> Stack;
+  const auto *LoopHeader = Loop.getHeader();
+
+  // Visit the header last
+  Finalized.insert(LoopHeader);
+  CallBack(*LoopHeader);
+
+  // Initialize with immediate successors
+  for (const auto *BB : successors(LoopHeader)) {
+    if (!Loop.contains(BB))
+      continue;
+    if (BB == LoopHeader)
+      continue;
+    Stack.push_back(BB);
+  }
+
+  // Compute PO inside region
+  computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
+}
+
+} // namespace
+
 namespace llvm {
 
-ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;
+ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
 
 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
                                                const PostDominatorTree &PDT,
                                                const LoopInfo &LI)
-    : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}
+    : DT(DT), PDT(PDT), LI(LI) {
+  computeTopLevelPO(*DT.getRoot()->getParent(), LI,
+                    [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
+}
 
 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
 
-using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
-
 // divergence propagator for reducible CFGs
 struct DivergencePropagator {
-  const FunctionRPOT &FuncRPOT;
+  const ModifiedPO &LoopPOT;
   const DominatorTree &DT;
   const PostDominatorTree &PDT;
   const LoopInfo &LI;
-
-  // identified join points
-  std::unique_ptr<ConstBlockSet> JoinBlocks;
-
-  // reached loop exits (by a path disjoint to a path to the loop header)
-  SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;
-
-  // if DefMap[B] == C then C is the dominating definition at block B
-  // if DefMap[B] ~ undef then we haven't seen B yet
-  // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
-  // an immediate successor of X (initial value).
-  using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
-  DefiningBlockMap DefMap;
-
-  // all blocks with pending visits
-  std::unordered_set<const BasicBlock *> PendingUpdates;
-
-  DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
-                       const PostDominatorTree &PDT, const LoopInfo &LI)
-      : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
-        JoinBlocks(new ConstBlockSet) {}
-
-  // set the definition at @block and mark @block as pending for a visit
-  void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
-    bool WasAdded = DefMap.emplace(&Block, &DefBlock).second;
-    if (WasAdded)
-      PendingUpdates.insert(&Block);
-  }
+  const BasicBlock &DivTermBlock;
+
+  // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
+  //   block B
+  // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
+  // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
+  // from X or B is an immediate successor of X (initial value).
+  using BlockLabelVec = std::vector<const BasicBlock *>;
+  BlockLabelVec BlockLabels;
+  // divergent join and loop exit descriptor.
+  std::unique_ptr<ControlDivergenceDesc> DivDesc;
+
+  DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
+                       const PostDominatorTree &PDT, const LoopInfo &LI,
+                       const BasicBlock &DivTermBlock)
+      : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
+        BlockLabels(LoopPOT.size(), nullptr),
+        DivDesc(new ControlDivergenceDesc) {}
 
   void printDefs(raw_ostream &Out) {
-    Out << "Propagator::DefMap {\n";
-    for (const auto *Block : FuncRPOT) {
-      auto It = DefMap.find(Block);
-      Out << Block->getName() << " : ";
-      if (It == DefMap.end()) {
-        Out << "\n";
+    Out << "Propagator::BlockLabels {\n";
+    for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
+      const auto *Label = BlockLabels[BlockIdx];
+      Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
+          << ") : ";
+      if (!Label) {
+        Out << "<null>\n";
       } else {
-        const auto *DefBlock = It->second;
-        Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
+        Out << Label->getName() << "\n";
       }
     }
     Out << "}\n";
   }
 
-  // process @succBlock with reaching definition @defBlock
-  // the original divergent branch was in @parentLoop (if any)
-  void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
-                      const BasicBlock &DefBlock) {
+  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
+  // causes a divergent join.
+  bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
+    auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
 
-    // @succBlock is a loop exit
-    if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
-      DefMap.emplace(&SuccBlock, &DefBlock);
-      ReachedLoopExits.insert(&SuccBlock);
-      return;
+    // unset or same reaching label
+    const auto *OldLabel = BlockLabels[SuccIdx];
+    if (!OldLabel || (OldLabel == &PushedLabel)) {
+      BlockLabels[SuccIdx] = &PushedLabel;
+      return false;
     }
 
-    // first reaching def?
-    auto ItLastDef = DefMap.find(&SuccBlock);
-    if (ItLastDef == DefMap.end()) {
-      addPending(SuccBlock, DefBlock);
-      return;
-    }
+    // Update the definition
+    BlockLabels[SuccIdx] = &SuccBlock;
+    return true;
+  }
 
-    // a join of at least two definitions
-    if (ItLastDef->second != &DefBlock) {
-      // do we know this join already?
-      if (!JoinBlocks->insert(&SuccBlock).second)
-        return;
+  // visiting a virtual loop exit edge from the loop header --> temporal
+  // divergence on join
+  bool visitLoopExitEdge(const BasicBlock &ExitBlock,
+                         const BasicBlock &DefBlock, bool FromParentLoop) {
+    // Pushing from a non-parent loop cannot cause temporal divergence.
+    if (!FromParentLoop)
+      return visitEdge(ExitBlock, DefBlock);
 
-      // update the definition
-      addPending(SuccBlock, SuccBlock);
-    }
+    if (!computeJoin(ExitBlock, DefBlock))
+      return false;
+
+    // Identified a divergent loop exit
+    DivDesc->LoopDivBlocks.insert(&ExitBlock);
+    LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
+                      << "\n");
+    return true;
   }
 
-  // find all blocks reachable by two disjoint paths from @rootTerm.
-  // This method works for both divergent terminators and loops with
-  // divergent exits.
-  // @rootBlock is either the block containing the branch or the header of the
-  // divergent loop.
-  // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
-  // headed by @rootBlock.
-  // @parentLoop is the parent loop of the Loop or the loop that contains the
-  // Terminator.
-  template <typename SuccessorIterable>
-  std::unique_ptr<ConstBlockSet>
-  computeJoinPoints(const BasicBlock &RootBlock,
-                    SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
-    assert(JoinBlocks);
-
-    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: "
-                      << (ParentLoop ? ParentLoop->getName() : "<null>")
+  // process \p SuccBlock with reaching definition \p DefBlock
+  bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
+    if (!computeJoin(SuccBlock, DefBlock))
+      return false;
+
+    // Divergent, disjoint paths join.
+    DivDesc->JoinDivBlocks.insert(&SuccBlock);
+    LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
+    return true;
+  }
+
+  std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
+    assert(DivDesc);
+
+    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
                       << "\n");
 
+    const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
+
+    // Early stopping criterion
+    int FloorIdx = LoopPOT.size() - 1;
+    const BasicBlock *FloorLabel = nullptr;
+
     // bootstrap with branch targets
-    for (const auto *SuccBlock : NodeSuccessors) {
-      DefMap.emplace(SuccBlock, SuccBlock);
+    int BlockIdx = 0;
 
-      if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
-        // immediate loop exit from node.
-        ReachedLoopExits.insert(SuccBlock);
-      } else {
-        // regular successor
-        PendingUpdates.insert(SuccBlock);
-      }
-    }
+    for (const auto *SuccBlock : successors(&DivTermBlock)) {
+      auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
+      BlockLabels[SuccIdx] = SuccBlock;
 
-    LLVM_DEBUG(dbgs() << "SDA: rpo order:\n"; for (const auto *RpoBlock
-                                                   : FuncRPOT) {
-      dbgs() << "- " << RpoBlock->getName() << "\n";
-    });
+      // Find the successor with the highest index to start with
+      BlockIdx = std::max<int>(BlockIdx, SuccIdx);
+      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
 
-    auto ItBeginRPO = FuncRPOT.begin();
-    auto ItEndRPO = FuncRPOT.end();
+      // Identify immediate divergent loop exits
+      if (!DivBlockLoop)
+        continue;
 
-    // skip until term (TODO RPOT won't let us start at @term directly)
-    for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {
-      assert(ItBeginRPO != ItEndRPO && "Unable to find RootBlock");
+      const auto *BlockLoop = LI.getLoopFor(SuccBlock);
+      if (BlockLoop && DivBlockLoop->contains(BlockLoop))
+        continue;
+      DivDesc->LoopDivBlocks.insert(SuccBlock);
+      LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
+                        << SuccBlock->getName() << "\n");
     }
 
     // propagate definitions at the immediate successors of the node in RPO
-    auto ItBlockRPO = ItBeginRPO;
-    while ((++ItBlockRPO != ItEndRPO) && !PendingUpdates.empty()) {
-      const auto *Block = *ItBlockRPO;
-      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
+    for (; BlockIdx >= FloorIdx; --BlockIdx) {
+      LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
 
-      // skip Block if not pending update
-      auto ItPending = PendingUpdates.find(Block);
-      if (ItPending == PendingUpdates.end())
+      // Any label available here
+      const auto *Label = BlockLabels[BlockIdx];
+      if (!Label)
         continue;
-      PendingUpdates.erase(ItPending);
 
-      // propagate definition at Block to its successors
-      auto ItDef = DefMap.find(Block);
-      const auto *DefBlock = ItDef->second;
-      assert(DefBlock);
+      // Ok. Get the block
+      const auto *Block = LoopPOT.getBlockAt(BlockIdx);
+      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
 
       auto *BlockLoop = LI.getLoopFor(Block);
-      if (ParentLoop &&
-          (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
-        // if the successor is the header of a nested loop pretend its a
-        // single node with the loop's exits as successors
+      bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
+      bool CausedJoin = false;
+      int LoweredFloorIdx = FloorIdx;
+      if (IsLoopHeader) {
+        // Disconnect from immediate successors and propagate directly to loop
+        // exits.
         SmallVector<BasicBlock *, 4> BlockLoopExits;
         BlockLoop->getExitBlocks(BlockLoopExits);
+
+        bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
         for (const auto *BlockLoopExit : BlockLoopExits) {
-          visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
+          CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
+          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
+                                          LoopPOT.getIndexOf(*BlockLoopExit));
         }
-
       } else {
-        // the successors are either on the same loop level or loop exits
+        // Acyclic successor case
         for (const auto *SuccBlock : successors(Block)) {
-          visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
+          CausedJoin |= visitEdge(*SuccBlock, *Label);
+          LoweredFloorIdx =
+              std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
         }
       }
-    }
 
-    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
-
-    // We need to know the definition at the parent loop header to decide
-    // whether the definition at the header is different from the definition at
-    // the loop exits, which would indicate a divergent loop exits.
-    //
-    // A // loop header
-    // |
-    // B // nested loop header
-    // |
-    // C -> X (exit from B loop) -..-> (A latch)
-    // |
-    // D -> back to B (B latch)
-    // |
-    // proper exit from both loops
-    //
-    // analyze reached loop exits
-    if (!ReachedLoopExits.empty()) {
-      const BasicBlock *ParentLoopHeader =
-          ParentLoop ? ParentLoop->getHeader() : nullptr;
-
-      assert(ParentLoop);
-      auto ItHeaderDef = DefMap.find(ParentLoopHeader);
-      const auto *HeaderDefBlock =
-          (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;
-
-      LLVM_DEBUG(printDefs(dbgs()));
-      assert(HeaderDefBlock && "no definition at header of carrying loop");
-
-      for (const auto *ExitBlock : ReachedLoopExits) {
-        auto ItExitDef = DefMap.find(ExitBlock);
-        assert((ItExitDef != DefMap.end()) &&
-               "no reaching def at reachable loop exit");
-        if (ItExitDef->second != HeaderDefBlock) {
-          JoinBlocks->insert(ExitBlock);
-        }
+      // Floor update
+      if (CausedJoin) {
+        // 1. Different labels pushed to successors
+        FloorIdx = LoweredFloorIdx;
+      } else if (FloorLabel != Label) {
+        // 2. No join caused BUT we pushed a label that is different than the
+        // last pushed label
+        FloorIdx = LoweredFloorIdx;
+        FloorLabel = Label;
       }
     }
 
-    return std::move(JoinBlocks);
+    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
+
+    return std::move(DivDesc);
   }
 };
 
-const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
-  using LoopExitVec = SmallVector<BasicBlock *, 4>;
-  LoopExitVec LoopExits;
-  Loop.getExitBlocks(LoopExits);
-  if (LoopExits.size() < 1) {
-    return EmptyBlockSet;
+static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
+  Out << "[";
+  bool First = true;
+  for (const auto *BB : Blocks) {
+    if (!First)
+      Out << ", ";
+    First = false;
+    Out << BB->getName();
   }
-
-  // already available in cache?
-  auto ItCached = CachedLoopExitJoins.find(&Loop);
-  if (ItCached != CachedLoopExitJoins.end()) {
-    return *ItCached->second;
-  }
-
-  // compute all join points
-  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
-  auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
-      *Loop.getHeader(), LoopExits, Loop.getParentLoop());
-
-  auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
-  assert(ItInserted.second);
-  return *ItInserted.first->second;
+  Out << "]";
 }
 
-const ConstBlockSet &
-SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
+const ControlDivergenceDesc &
+SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
   // trivial case
-  if (Term.getNumSuccessors() < 1) {
-    return EmptyBlockSet;
+  if (Term.getNumSuccessors() <= 1) {
+    return EmptyDivergenceDesc;
   }
 
   // already available in cache?
-  auto ItCached = CachedBranchJoins.find(&Term);
-  if (ItCached != CachedBranchJoins.end())
+  auto ItCached = CachedControlDivDescs.find(&Term);
+  if (ItCached != CachedControlDivDescs.end())
     return *ItCached->second;
 
   // compute all join points
-  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
+  // Special handling of divergent loop exits is not needed for LCSSA
   const auto &TermBlock = *Term.getParent();
-  auto JoinBlocks = Propagator.computeJoinPoints<const_succ_range>(
-      TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
+  DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
+  auto DivDesc = Propagator.computeJoinPoints();
+
+  LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
+             dbgs() << "JoinDivBlocks: ";
+             printBlockSet(DivDesc->JoinDivBlocks, dbgs());
+             dbgs() << "\nLoopDivBlocks: ";
+             printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
 
-  auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
+  auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
   assert(ItInserted.second);
   return *ItInserted.first->second;
 }

diff  --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll
index 12e2b0ffd443..774e995c7ca2 100644
--- a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll
+++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll
@@ -119,9 +119,8 @@ L:
   br i1 %uni.cond, label %D, label %G
 
 X:
-  %div.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ] ; temporal divergent phi
+  %uni.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ]
   br i1 %uni.cond, label %Y, label %exit
-; CHECK: DIVERGENT: %div.merge.x =
 
 Y:
   %div.merge.y = phi i32 [ 42, %X ], [ %b, %C ]

diff  --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll
index 8ad848af41f5..b872dd8966bc 100644
--- a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll
+++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll
@@ -1,7 +1,4 @@
 ; RUN: opt -mtriple amdgcn-unknown-amdhsa -analyze -divergence -use-gpu-divergence-analysis %s | FileCheck %s
-; XFAIL: *
-
-; https://bugs.llvm.org/show_bug.cgi?id=46372
 
 ; CHECK: bb2:
 ; CHECK-NOT: DIVERGENT:       %Guard.bb2 = phi i1 [ true, %bb1 ], [ false, %bb0 ]


        


More information about the llvm-commits mailing list