[llvm] [UniformAnalysis] Use Immediate postDom as last join (PR #140013)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 14 23:53:40 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Junjie Gu (jgu222)

<details>
<summary>Changes</summary>

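Given a block whose terminator is a divergent branch, computeJoinPoints uses FloorIdx as an early-stopping criterion. However, that criterion is incorrect in some cases (shown by the two new lit tests): propagation can stop before all join points are found, so PHIs that should be divergent (e.g. those in block B6 of phi_div_branch.ll) may not be reported as divergent.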

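This change instead uses the immediate post-dominator of the divergent block as the last join: label propagation stops when it reaches that block. To obtain the immediate post-dominator, a post-dominator tree is now passed into GenericUniformityImpl, and both the IR and Machine uniformity analyses now require a post-dominator tree analysis.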

---

Patch is 22.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140013.diff


8 Files Affected:

- (modified) llvm/include/llvm/ADT/GenericSSAContext.h (+4) 
- (modified) llvm/include/llvm/ADT/GenericUniformityImpl.h (+44-47) 
- (modified) llvm/include/llvm/ADT/GenericUniformityInfo.h (+3-1) 
- (modified) llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h (+3-1) 
- (modified) llvm/lib/Analysis/UniformityAnalysis.cpp (+8-2) 
- (modified) llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (+11-4) 
- (added) llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll (+78) 
- (added) llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll (+82) 


``````````diff
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..e99d4b1c6dd45 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
   // a given funciton.
   using DominatorTreeT = DominatorTreeBase<BlockT, false>;
 
+  // A post-dominator tree provides the post-dominance relation between
+  // basic blocks in a given function.
+  using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;
+
   GenericSSAContext() = default;
   GenericSSAContext(const FunctionT *F) : F(F) {}
 
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d10355fff1bea..f404577bb7e56 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
 public:
   using BlockT = typename ContextT::BlockT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
   using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
   using DivergencePropagatorT = DivergencePropagator<ContextT>;
 
   GenericSyncDependenceAnalysis(const ContextT &Context,
-                                const DominatorTreeT &DT, const CycleInfoT &CI);
+                                const DominatorTreeT &DT,
+                                const PostDominatorTreeT &PDT,
+                                const CycleInfoT &CI);
 
   /// \brief Computes divergent join points and cycle exits caused by branch
   /// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
   ModifiedPO CyclePO;
 
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
   const CycleInfoT &CI;
 
   DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
 
   using CycleInfoT = GenericCycleInfo<ContextT>;
   using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   using TemporalDivergenceTuple =
       std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
 
-  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
+  GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
+                                const PostDominatorTreeT &PDT,
+                                const CycleInfoT &CI,
                                 const TargetTransformInfo *TTI)
       : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
-        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
+        TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}
 
   void initialize();
 
@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
 
 private:
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
 
   // Recognized cycles with divergent exits.
   SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
 public:
   using BlockT = typename ContextT::BlockT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
 
@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
 
   const ModifiedPO &CyclePOT;
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
   const CycleInfoT &CI;
   const BlockT &DivTermBlock;
   const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
   BlockLabelMapT &BlockLabels;
 
   DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
-                       const CycleInfoT &CI, const BlockT &DivTermBlock)
-      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
-        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
-        BlockLabels(DivDesc->BlockLabels) {}
+                       const PostDominatorTreeT &PDT, const CycleInfoT &CI,
+                       const BlockT &DivTermBlock)
+      : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
+        DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
+        DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
 
   void printDefs(raw_ostream &Out) {
     Out << "Propagator::BlockLabels {\n";
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
     Out << "}\n";
   }
 
+  const BlockT *getIPDom(const BlockT *B) {
+    const auto *Node = PDT.getNode(B);
+    const auto *IPDomNode = Node->getIDom();
+    return IPDomNode->getBlock();
+  }
+
   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
   // causes a divergent join.
   bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                       << Context.print(&DivTermBlock) << "\n");
 
-    // Early stopping criterion
-    int FloorIdx = CyclePOT.size() - 1;
-    const BlockT *FloorLabel = nullptr;
-    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
+    // The immediate post-dominator of DivTermBlock is the last join
+    // point to visit; label propagation does not continue past it.
+    const auto *ImmPDom = getIPDom(&DivTermBlock);
+
+    LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");
 
     // Bootstrap with branch targets
     auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
         LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                           << Context.print(SuccBlock) << "\n");
       }
-      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
       visitEdge(*SuccBlock, *SuccBlock);
-      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
     }
 
     while (true) {
       auto BlockIdx = FreshLabels.find_last();
-      if (BlockIdx == -1 || BlockIdx < FloorIdx)
+      if (BlockIdx == -1)
         break;
 
       LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
 
       FreshLabels.reset(BlockIdx);
-      if (BlockIdx == DivTermIdx) {
-        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
+      const auto *Block = CyclePOT[BlockIdx];
+      if (Block == ImmPDom) {
+        LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
         continue;
       }
 
-      const auto *Block = CyclePOT[BlockIdx];
       LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                         << BlockIdx << "\n");
 
       const auto *Label = BlockLabels[Block];
       assert(Label);
 
-      bool CausedJoin = false;
-      int LoweredFloorIdx = FloorIdx;
-
       // If the current block is the header of a reducible cycle that
       // contains the divergent branch, then the label should be
       // propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
       if (const auto *BlockCycle = getReducibleParent(Block)) {
         SmallVector<BlockT *, 4> BlockCycleExits;
         BlockCycle->getExitBlocks(BlockCycleExits);
-        for (auto *BlockCycleExit : BlockCycleExits) {
-          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
-        }
+        for (auto *BlockCycleExit : BlockCycleExits)
+          visitCycleExitEdge(*BlockCycleExit, *Label);
       } else {
-        for (const auto *SuccBlock : successors(Block)) {
-          CausedJoin |= visitEdge(*SuccBlock, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
-        }
-      }
-
-      // Floor update
-      if (CausedJoin) {
-        // 1. Different labels pushed to successors
-        FloorIdx = LoweredFloorIdx;
-      } else if (FloorLabel != Label) {
-        // 2. No join caused BUT we pushed a label that is different than the
-        // last pushed label
-        FloorIdx = LoweredFloorIdx;
-        FloorLabel = Label;
+        for (const auto *SuccBlock : successors(Block))
+          visitEdge(*SuccBlock, *Label);
       }
     }
 
@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
 
 template <typename ContextT>
 llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
-    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
-    : CyclePO(Context), DT(DT), CI(CI) {
+    const ContextT &Context, const DominatorTreeT &DT,
+    const PostDominatorTreeT &PDT, const CycleInfoT &CI)
+    : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
   CyclePO.compute(CI);
 }
 
@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
     return *ItCached->second;
 
   // compute all join points
-  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
+  DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
   auto DivDesc = Propagator.computeJoinPoints();
 
   auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
 
 template <typename ContextT>
 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
-    const DominatorTreeT &DT, const CycleInfoT &CI,
-    const TargetTransformInfo *TTI) {
-  DA.reset(new ImplT{DT, CI, TTI});
+    const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+    const CycleInfoT &CI, const TargetTransformInfo *TTI) {
+  DA.reset(new ImplT{DT, PDT, CI, TTI});
 }
 
 template <typename ContextT>
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 9376fa6ee0bae..62d35582823dc 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
   using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using ThisT = GenericUniformityInfo<ContextT>;
 
   using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
   using TemporalDivergenceTuple =
       std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
 
-  GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
+  GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+                        const CycleInfoT &CI,
                         const TargetTransformInfo *TTI = nullptr);
   GenericUniformityInfo() = default;
   GenericUniformityInfo(GenericUniformityInfo &&) = default;
diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
index e8c0dc9b43823..03fc9ebfcf442 100644
--- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
+++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 
 namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 /// everything is uniform.
 MachineUniformityInfo computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence);
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
 
 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
 class MachineUniformityAnalysisPass : public MachineFunctionPass {
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 2101fdfacfc8f..a724a8c26d7db 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/ADT/GenericUniformityImpl.h"
 #include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
                                                  FunctionAnalysisManager &FAM) {
   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
   auto &CI = FAM.getResult<CycleAnalysis>(F);
-  UniformityInfo UI{DT, CI, &TTI};
+  UniformityInfo UI{DT, PDT, CI, &TTI};
   // Skip computation if we can assume everything is uniform.
   if (TTI.hasBranchDivergence(&F))
     UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                       "Uniformity Analysis", true, true)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTreeWrapperPass>();
   AU.addRequiredTransitive<CycleInfoWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
 }
@@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
   auto &targetTransformInfo =
       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
   m_function = &F;
-  m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
+  m_uniformityInfo =
+      UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
 
   // Skip computation if we can assume everything is uniform.
   if (targetTransformInfo.hasBranchDivergence(m_function))
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index e4b82ce83fda6..b87f8357ecfa8 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -11,6 +11,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 
 MachineUniformityInfo llvm::computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence) {
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
-  MachineUniformityInfo UI(domTree, cycleInfo);
+  MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
   if (HasBranchDivergence)
     UI.compute();
   return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
 MachineUniformityAnalysis::run(MachineFunction &MF,
                                MachineFunctionAnalysisManager &MFAM) {
   auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
+  auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
   auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
   auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
                   .getManager();
   auto &F = MF.getFunction();
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
-  return computeMachineUniformityInfo(MF, CI, DomTree,
+  return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
                                       TTI.hasBranchDivergence(&F));
 }
 
@@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                       "Machine Uniformity Info Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                     "Machine Uniformity Info Analysis", false, true)
 
@@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
   AU.addRequired<MachineDominatorTreeWrapperPass>();
+  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
   auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+  auto &PDomTree =
+      getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
   // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
   // default NoTTI
-  UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
+  UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
   return false;
 }
 
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
new file mode 100644
index 0000000000000..df949a86635c4
--- /dev/null
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
@@ -0,0 +1,78 @@
+;
+; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
+;
+; Tests an if-then-else whose arms are chains of unmerged basic blocks
+; (https://github.com/llvm/llvm-project/issues/137277)
+;
+;      Entry (div.cond)
+;      /   \
+;     B0   B3
+;     |    |
+;     B1   B4
+;     |    |
+;     B2   B5
+;      \  /
+;       B6 (phi: divergent)
+;
+
+
+; CHECK-LABEL:  'test_ctrl_divergence':
+; CHECK-LABEL:  BLOCK Entry
+; CHECK:  DIVERGENT:   %div.cond = icmp eq i32 %tid, 0
+; CHECK:  DIVERGENT:   br i1 %div.cond, label %B3, label %B0
+;
+; CHECK-LABEL:  BLOCK B6
+; CHECK:  DIVERGENT:   %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
+; CHECK:  DIVERGENT:   %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
+; CHECK:  DIVERGENT:   %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
+
+
+define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
+Entry:
+  %tid = call i32 @llvm.amdgcn.workitem.id.x()
+  %div.cond = icmp eq i32 %tid, 0
+  br i1 %div.cond, label %B3, label %B0 ; divergent branch
+
+B0:
+  %a0 = add i32 %a, 1
+  br label %B1
+
+B1:
+  %...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/140013


More information about the llvm-commits mailing list