[llvm] 0a170eb - [Uniformity] Propagate divergence only along divergent outputs.

Sameer Sahasrabuddhe via llvm-commits llvm-commits at lists.llvm.org
Tue May 16 19:18:34 PDT 2023


Author: Sameer Sahasrabuddhe
Date: 2023-05-17T07:47:43+05:30
New Revision: 0a170eb7866b72a9aae0498c20bdd4befde9fce5

URL: https://github.com/llvm/llvm-project/commit/0a170eb7866b72a9aae0498c20bdd4befde9fce5
DIFF: https://github.com/llvm/llvm-project/commit/0a170eb7866b72a9aae0498c20bdd4befde9fce5.diff

LOG: [Uniformity] Propagate divergence only along divergent outputs.

When an instruction is determined to be divergent, not all of its outputs are
necessarily divergent. Only the users of its divergent outputs now need to be
examined for divergence.

Also, replaced the repeated pattern "if the instruction is newly divergent, then
add it to the worklist" with a single function. That refactoring by itself does
not change functionality.

Reviewed By: foad, arsenm

Differential Revision: https://reviews.llvm.org/D150636

Added: 
    

Modified: 
    llvm/include/llvm/ADT/GenericUniformityImpl.h
    llvm/lib/Analysis/UniformityAnalysis.cpp
    llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
    llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index 75a33e19210fe..71935d12ea03f 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -355,10 +355,15 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   /// \brief Mark \p UniVal as a value that is always uniform.
   void addUniformOverride(const InstructionT &Instr);
 
-  /// \brief Mark \p DivVal as a value that is always divergent.
+  /// \brief Examine \p I for divergent outputs and add to the worklist.
+  void markDivergent(const InstructionT &I);
+
+  /// \brief Mark \p DivVal as a divergent value.
   /// \returns Whether the tracked divergence state of \p DivVal changed.
-  bool markDivergent(const InstructionT &I);
   bool markDivergent(ConstValueRefT DivVal);
+
+  /// \brief Mark outputs of \p Instr as divergent.
+  /// \returns Whether the tracked divergence state of any output has changed.
   bool markDefsDivergent(const InstructionT &Instr);
 
   /// \brief Propagate divergence to all instructions in the region.
@@ -774,21 +779,23 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
 }
 
 template <typename ContextT>
-bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
+void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
     const InstructionT &I) {
+  if (isAlwaysUniform(I))
+    return;
+  bool Marked = false;
   if (I.isTerminator()) {
-    if (DivergentTermBlocks.insert(I.getParent()).second) {
+    Marked = DivergentTermBlocks.insert(I.getParent()).second;
+    if (Marked) {
       LLVM_DEBUG(dbgs() << "marked divergent term block: "
                         << Context.print(I.getParent()) << "\n");
-      return true;
     }
-    return false;
+  } else {
+    Marked = markDefsDivergent(I);
   }
 
-  if (isAlwaysUniform(I))
-    return false;
-
-  return markDefsDivergent(I);
+  if (Marked)
+    Worklist.push_back(&I);
 }
 
 template <typename ContextT>
@@ -828,8 +835,7 @@ void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
   for (auto *Exit : Exits) {
     for (auto &Phi : Exit->phis()) {
       if (usesValueFromCycle(Phi, DefCycle)) {
-        if (markDivergent(Phi))
-          Worklist.push_back(&Phi);
+        markDivergent(Phi);
       }
     }
   }
@@ -889,8 +895,7 @@ void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
     if (I.isTerminator())
       break;
 
-    if (markDivergent(I))
-      Worklist.push_back(&I);
+    markDivergent(I);
   }
 }
 
@@ -910,8 +915,7 @@ void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
     // https://reviews.llvm.org/D19013
     if (ContextT::isConstantOrUndefValuePhi(Phi))
       continue;
-    if (markDivergent(Phi))
-      Worklist.push_back(&Phi);
+    markDivergent(Phi);
   }
 }
 

diff  --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index fad88bb6f2c98..60d6bb881940a 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -27,7 +27,7 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
 template <>
 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
     const Instruction &Instr) {
-  return markDivergent(&Instr);
+  return markDivergent(cast<Value>(&Instr));
 }
 
 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
@@ -49,9 +49,7 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
     const Value *V) {
   for (const auto *User : V->users()) {
     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
-      if (markDivergent(*UserInstr)) {
-        Worklist.push_back(UserInstr);
-      }
+      markDivergent(*UserInstr);
     }
   }
 }
@@ -88,8 +86,7 @@ void llvm::GenericUniformityAnalysisImpl<
     auto *UserInstr = cast<Instruction>(User);
     if (DefCycle.contains(UserInstr->getParent()))
       continue;
-    if (markDivergent(*UserInstr))
-      Worklist.push_back(UserInstr);
+    markDivergent(*UserInstr);
   }
 }
 

diff  --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index 693c64eabce9c..cc8cdaff9f0ed 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -62,8 +62,7 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
       }
 
       if (uniformity == InstructionUniformity::NeverUniform) {
-        if (markDivergent(instr))
-          Worklist.push_back(&instr);
+        markDivergent(instr);
       }
     }
   }
@@ -72,10 +71,10 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
 template <>
 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
     Register Reg) {
+  assert(isDivergent(Reg));
   const auto &RegInfo = F.getRegInfo();
   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
-    if (markDivergent(UserInstr))
-      Worklist.push_back(&UserInstr);
+    markDivergent(UserInstr);
   }
 }
 
@@ -86,8 +85,11 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
   if (Instr.isTerminator())
     return;
   for (const MachineOperand &op : Instr.operands()) {
-    if (op.isReg() && op.isDef() && op.getReg().isVirtual())
-      pushUsers(op.getReg());
+    if (!op.isReg() || !op.isDef())
+      continue;
+    auto Reg = op.getReg();
+    if (isDivergent(Reg))
+      pushUsers(Reg);
   }
 }
 
@@ -128,8 +130,7 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
       if (DefCycle.contains(UserInstr.getParent()))
         continue;
-      if (markDivergent(UserInstr))
-        Worklist.push_back(&UserInstr);
+      markDivergent(UserInstr);
     }
   }
 }

diff  --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
index bae97172f8ea2..6a0b5bb107bf2 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
@@ -97,7 +97,6 @@ body:             |
 
 ...
 
-# FIXME :: BELOW INLINE ASM SHOULD BE DIVERGENT
 ---
 name:            asm_mixed_sgpr_vgpr
 registers:
@@ -116,7 +115,9 @@ body:             |
     ; CHECK-LABEL: MachineUniformityInfo for function: asm_mixed_sgpr_vgpr
     ; CHECK: DIVERGENT: %0:
     ; CHECK: DIVERGENT: %3:
+    ; CHECK-NOT: DIVERGENT: %1:
     ; CHECK: DIVERGENT: %2:
+    ; CHECK-NOT: DIVERGENT: %4:
     ; CHECK: DIVERGENT: %5:
     %0:_(s32) = COPY $vgpr0
     %6:_(p1) = G_IMPLICIT_DEF


        


More information about the llvm-commits mailing list