[llvm] b0f0dd2 - [LLVM][Uniformity] Propagate temporal divergence explicitly

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Mon May 15 07:48:44 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-05-15T20:17:43+05:30
New Revision: b0f0dd2554c726e5192ad8c98fb7a2f08c37994c

URL: https://github.com/llvm/llvm-project/commit/b0f0dd2554c726e5192ad8c98fb7a2f08c37994c
DIFF: https://github.com/llvm/llvm-project/commit/b0f0dd2554c726e5192ad8c98fb7a2f08c37994c.diff

LOG: [LLVM][Uniformity] Propagate temporal divergence explicitly

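At a cycle C with divergent exits, UniformityAnalysis (UA) was using a naive
traversal of the exiting edges to locate blocks that may use values defined
inside C. But this traversal fails when it encounters another cycle outside C:
a block joins the traversed region only after all of its predecessors are in C
or already in the region, which never happens for, e.g., the header of a loop
after the exit, since its latch is only reached through that header. The
traversal is now replaced with a much simpler propagation: check the PHIs in
the exit blocks of C, and check every use outside C of each instruction in a
block of C that dominates an exit of C. Such an iteration can be expensive when
C is very large; the original strategy may need to be reconsidered if there is
a regression in compilation times.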

Also fixed lit tests that should originally have caught the missed propagation
of temporal divergence, and added tests for PHIs at and after a divergent exit.

Reviewed By: foad

Differential Revision: https://reviews.llvm.org/D149646

Added: 
    

Modified: 
    llvm/include/llvm/ADT/GenericUniformityImpl.h
    llvm/lib/Analysis/UniformityAnalysis.cpp
    llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
    llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index 7eff61b26ba56..4b595a102fabf 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -451,13 +451,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   void propagateCycleExitDivergence(const BlockT &DivExit,
                                     const CycleT &DivCycle);
 
-  /// \brief Internal implementation function for propagateCycleExitDivergence.
-  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);
+  /// Mark as divergent all external uses of values defined in \p DefCycle.
+  void analyzeCycleExitDivergence(const CycleT &DefCycle);
 
-  /// \brief Mark all instruction as divergent that use a value defined in \p
-  /// OuterDivCycle. Push their users on the worklist.
-  void analyzeTemporalDivergence(const InstructionT &I,
-                                 const CycleT &OuterDivCycle);
+  /// \brief Mark divergent (and enqueue) all users of \p I outside \p DefCycle.
+  void propagateTemporalDivergence(const InstructionT &I,
+                                   const CycleT &DefCycle);
 
   /// \brief Push all users of \p Val (in the region) to the worklist.
   void pushUsers(const InstructionT &I);
@@ -809,106 +808,39 @@ void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
   UniformOverrides.insert(&Instr);
 }
 
-template <typename ContextT>
-void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
-    const InstructionT &I, const CycleT &OuterDivCycle) {
-  if (isDivergent(I))
-    return;
-
-  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
-                    << "\n");
-  if (isAlwaysUniform(I))
-    return;
-
-  if (!usesValueFromCycle(I, OuterDivCycle))
-    return;
-
-  if (markDivergent(I))
-    Worklist.push_back(&I);
-}
-
-// Mark all external users of values defined inside \param
-// OuterDivCycle as divergent.
+// Mark as divergent all external uses of values defined in \p DefCycle.
+//
+// A value V defined by a block B inside \p DefCycle may be used outside the
+// cycle only if the use is a PHI in some exit block, or B dominates some exit
+// block. Thus, we check uses as follows:
 //
-// This follows all live out edges wherever they may lead. Potential
-// users of values defined inside DivCycle could be anywhere in the
-// dominance region of DivCycle (including its fringes for phi nodes).
-// A cycle C dominates a block B iff every path from the entry block
-// to B must pass through a block contained in C. If C is a reducible
-// cycle (or natural loop), C dominates B iff the header of C
-// dominates B. But in general, we iteratively examine cycle cycle
-// exits and their successors.
+// - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
+// - For every block B inside \p DefCycle that dominates at least one exit
+//   block, check all uses outside \p DefCycle via propagateTemporalDivergence.
+//
+// FIXME: This function does not distinguish between divergent and uniform
+// exits. For each divergent exit, only the values that are live at that exit
+// need to be propagated as divergent at their use outside the cycle.
 template <typename ContextT>
 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
-    const CycleT &OuterDivCycle) {
-  // Set of blocks that are dominated by the cycle, i.e., each is only
-  // reachable from paths that pass through the cycle.
-  SmallPtrSet<BlockT *, 16> DomRegion;
-
-  // The boundary of DomRegion, formed by blocks that are not
-  // dominated by the cycle.
-  SmallVector<BlockT *> DomFrontier;
-  OuterDivCycle.getExitBlocks(DomFrontier);
-
-  // Returns true if BB is dominated by the cycle.
-  auto isInDomRegion = [&](BlockT *BB) {
-    for (auto *P : predecessors(BB)) {
-      if (OuterDivCycle.contains(P))
-        continue;
-      if (DomRegion.count(P))
-        continue;
-      return false;
-    }
-    return true;
-  };
-
-  // Keep advancing the frontier along successor edges, while
-  // promoting blocks to DomRegion.
-  while (true) {
-    bool Promoted = false;
-    SmallVector<BlockT *> Temp;
-    for (auto *W : DomFrontier) {
-      if (!isInDomRegion(W)) {
-        Temp.push_back(W);
-        continue;
-      }
-      DomRegion.insert(W);
-      Promoted = true;
-      for (auto *Succ : successors(W)) {
-        if (DomRegion.contains(Succ))
-          continue;
-        Temp.push_back(Succ);
+    const CycleT &DefCycle) {
+  SmallVector<BlockT *> Exits;
+  DefCycle.getExitBlocks(Exits);
+  for (auto *Exit : Exits) {
+    for (auto &Phi : Exit->phis()) {
+      if (usesValueFromCycle(Phi, DefCycle)) {
+        if (markDivergent(Phi))
+          Worklist.push_back(&Phi);
       }
     }
-    if (!Promoted)
-      break;
-
-    // Restore the set property for the temporary vector
-    llvm::sort(Temp);
-    Temp.erase(std::unique(Temp.begin(), Temp.end()), Temp.end());
-
-    DomFrontier = Temp;
   }
 
-  // At DomFrontier, only the PHI nodes are affected by temporal
-  // divergence.
-  for (const auto *UserBlock : DomFrontier) {
-    LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
-                      << Context.print(UserBlock) << "\n");
-    for (const auto &Phi : UserBlock->phis()) {
-      LLVM_DEBUG(dbgs() << "  " << Context.print(&Phi) << "\n");
-      analyzeTemporalDivergence(Phi, OuterDivCycle);
-    }
-  }
-
-  // All instructions inside the dominance region are affected by
-  // temporal divergence.
-  for (const auto *UserBlock : DomRegion) {
-    LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
-                      << Context.print(UserBlock) << "\n");
-    for (const auto &I : *UserBlock) {
-      LLVM_DEBUG(dbgs() << "  " << Context.print(&I) << "\n");
-      analyzeTemporalDivergence(I, OuterDivCycle);
+  for (auto *BB : DefCycle.blocks()) {
+    if (!llvm::any_of(Exits,
+                     [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
+      continue; // BB's values escape only via exit PHIs (handled above).
+    for (auto &II : *BB) {
+      propagateTemporalDivergence(II, DefCycle);
     }
   }
 }

diff  --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 13a9c2b7e4438..af766ef68bbba 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -78,6 +78,21 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
   return false;
 }
 
+template <>
+void llvm::GenericUniformityAnalysisImpl<
+    SSAContext>::propagateTemporalDivergence(const Instruction &I,
+                                             const Cycle &DefCycle) {
+  if (isDivergent(I))
+    return;
+  for (auto *User : I.users()) {
+    auto *UserInstr = cast<Instruction>(User);
+    if (DefCycle.contains(UserInstr->getParent()))
+      continue; // Only uses outside DefCycle are affected.
+    if (markDivergent(*UserInstr))
+      Worklist.push_back(UserInstr);
+  }
+}
+
 template <>
 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
     const Use &U) const {

diff  --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index ef67bae1c1af0..22f38ae349b86 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -113,6 +113,28 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
   return false;
 }
 
+template <>
+void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
+    propagateTemporalDivergence(const MachineInstr &I,
+                                const MachineCycle &DefCycle) {
+  const auto &RegInfo = F.getRegInfo();
+  for (auto &Op : I.operands()) {
+    if (!Op.isReg() || !Op.isDef())
+      continue; // Only register definitions are propagated.
+    if (!Op.getReg().isVirtual())
+      continue; // Only virtual registers are considered.
+    auto Reg = Op.getReg();
+    if (isDivergent(Reg))
+      continue;
+    for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
+      if (DefCycle.contains(UserInstr.getParent()))
+        continue; // Only uses outside DefCycle are affected.
+      if (markDivergent(UserInstr))
+        Worklist.push_back(&UserInstr);
+    }
+  }
+}
+
 template <>
 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
     const MachineOperand &U) const {

diff  --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
index 842636aa952f0..6bab909276cb6 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
@@ -14,13 +14,67 @@ entry:
 H:
  %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
  %uni.inc = add i32 %uni.merge.h, 1
+; CHECK: DIVERGENT: %div.exitx =
  %div.exitx = icmp slt i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.exitx,
  br i1 %div.exitx, label %X, label %H ; divergent branch
+
+X:
+; CHECK: DIVERGENT:  %div.user =
+  %div.user = add i32 %uni.inc, 5
+  ret void
+}
+; Temporal divergence observed by a PHI in the exit block X.
+define amdgpu_kernel void @phi_at_exit(i32 %n, i32 %a, i32 %b) #0 {
+; CHECK-LABEL: for function 'phi_at_exit':
+; CHECK-NOT: DIVERGENT: %uni.
+; CHECK-NOT: DIVERGENT: br i1 %uni.
+
+entry:
+  %tid = call i32 @llvm.amdgcn.workitem.id.x()
+  %uni.cond = icmp slt i32 %a, 0
+  br i1 %uni.cond, label %H, label %X
+
+H:
+  %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+  %uni.inc = add i32 %uni.merge.h, 1
 ; CHECK: DIVERGENT: %div.exitx =
+  %div.exitx = icmp slt i32 %tid, 0
 ; CHECK: DIVERGENT: br i1 %div.exitx,
+  br i1 %div.exitx, label %X, label %H ; divergent branch
 
 X:
-  %div.user = add i32 %uni.inc, 5
+; CHECK: DIVERGENT:  %div.phi =
+  %div.phi = phi i32 [ 0, %entry], [ %uni.inc, %H ]
+  %div.user = add i32 %div.phi, 5
+  ret void
+}
+; Temporal divergence observed by a PHI in Y, reached via the exit block X.
+define amdgpu_kernel void @phi_after_exit(i32 %n, i32 %a, i32 %b) #0 {
+; CHECK-LABEL: for function 'phi_after_exit':
+; CHECK-NOT: DIVERGENT: %uni.
+; CHECK-NOT: DIVERGENT: br i1 %uni.
+
+entry:
+  %tid = call i32 @llvm.amdgcn.workitem.id.x()
+  %uni.cond = icmp slt i32 %a, 0
+  br i1 %uni.cond, label %H, label %Y
+
+H:
+  %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+  %uni.inc = add i32 %uni.merge.h, 1
+; CHECK: DIVERGENT: %div.exitx =
+  %div.exitx = icmp slt i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.exitx,
+  br i1 %div.exitx, label %X, label %H ; divergent branch
+
+X:
+  br label %Y
+
+Y:
+; CHECK: DIVERGENT:  %div.phi =
+  %div.phi = phi i32 [ 0, %entry], [ %uni.inc, %X ]
+  %div.user = add i32 %div.phi, 5
   ret void
 }
 
@@ -56,7 +110,8 @@ Y:
 }
 
 
-; temporal-uniform use of a valud, definition and users are carried by a surrounding divergent loop
+; temporal-uniform use of a value, definition and users are carried by a
+; surrounding divergent loop
 define amdgpu_kernel void @temporal_uniform_indivloop(i32 %n, i32 %a, i32 %b) #0 {
 ; CHECK-LABEL: for function 'temporal_uniform_indivloop':
 ; CHECK-NOT: DIVERGENT: %uni.
@@ -85,6 +140,7 @@ X:
 Y:
   %div.alsouser = add i32 %uni.inc, 5
   ret void
+; CHECK: DIVERGENT: %div.alsouser =
 }
 
 
@@ -113,6 +169,7 @@ X:
 G:
   %div.user = add i32 %uni.inc, 5
   br i1 %uni.cond, label %G, label %Y
+; CHECK: DIVERGENT: %div.user =
 
 Y:
   ret void
@@ -127,10 +184,13 @@ define amdgpu_kernel void @temporal_diverge_loopuser_nested(i32 %n, i32 %a, i32
 entry:
   %tid = call i32 @llvm.amdgcn.workitem.id.x()
   %uni.cond = icmp slt i32 %a, 0
+  br label %G
+; Outer cycle G -> H -> X -> G encloses the divergent loop at H.
+G:
   br label %H
 
 H:
-  %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+  %uni.merge.h = phi i32 [ 0, %G ], [ %uni.inc, %H ]
   %uni.inc = add i32 %uni.merge.h, 1
   %div.exitx = icmp slt i32 %tid, 0
   br i1 %div.exitx, label %X, label %H ; divergent branch
@@ -138,17 +198,14 @@ H:
 ; CHECK: DIVERGENT: br i1 %div.exitx,
 
 X:
-  br label %G
-
-G:
+; CHECK: DIVERGENT: %div.user =
   %div.user = add i32 %uni.inc, 5
-  br i1 %uni.cond, label %G, label %Y
+  br i1 %uni.cond, label %X, label %G
 
 Y:
   ret void
 }
 
-
 declare i32 @llvm.amdgcn.workitem.id.x() #0
 
 attributes #0 = { nounwind readnone }


        


More information about the llvm-commits mailing list