[llvm] b0f0dd2 - [LLVM][Uniformity] Propagate temporal divergence explicitly
Sameer Sahasrabuddhe via llvm-commits
llvm-commits at lists.llvm.org
Mon May 15 07:48:44 PDT 2023
Author: Sameer Sahasrabuddhe
Date: 2023-05-15T20:17:43+05:30
New Revision: b0f0dd2554c726e5192ad8c98fb7a2f08c37994c
URL: https://github.com/llvm/llvm-project/commit/b0f0dd2554c726e5192ad8c98fb7a2f08c37994c
DIFF: https://github.com/llvm/llvm-project/commit/b0f0dd2554c726e5192ad8c98fb7a2f08c37994c.diff
LOG: [LLVM][Uniformity] Propagate temporal divergence explicitly
At a cycle C with divergent exits, the uniformity analysis (UA) used a naive
traversal of the exiting edges to locate blocks that may use values defined
inside C, but this traversal fails when it encounters a cycle. It is now
replaced with a much simpler propagation that iterates over every instruction
in C and checks for any uses that are outside C. Such an iteration can be
expensive when C is very large; the original strategy may need to be
reconsidered if compilation times regress.
Also fixes the lit tests that should originally have caught the missed
propagation of temporal divergence.
Reviewed By: foad
Differential Revision: https://reviews.llvm.org/D149646
Added:
Modified:
llvm/include/llvm/ADT/GenericUniformityImpl.h
llvm/lib/Analysis/UniformityAnalysis.cpp
llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index 7eff61b26ba56..4b595a102fabf 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -451,13 +451,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
void propagateCycleExitDivergence(const BlockT &DivExit,
const CycleT &DivCycle);
- /// \brief Internal implementation function for propagateCycleExitDivergence.
- void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);
+ /// Mark as divergent all external uses of values defined in \p DefCycle.
+ void analyzeCycleExitDivergence(const CycleT &DefCycle);
- /// \brief Mark all instruction as divergent that use a value defined in \p
- /// OuterDivCycle. Push their users on the worklist.
- void analyzeTemporalDivergence(const InstructionT &I,
- const CycleT &OuterDivCycle);
+ /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
+ void propagateTemporalDivergence(const InstructionT &I,
+ const CycleT &DefCycle);
/// \brief Push all users of \p Val (in the region) to the worklist.
void pushUsers(const InstructionT &I);
@@ -809,106 +808,39 @@ void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
UniformOverrides.insert(&Instr);
}
-template <typename ContextT>
-void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
- const InstructionT &I, const CycleT &OuterDivCycle) {
- if (isDivergent(I))
- return;
-
- LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << Context.print(&I)
- << "\n");
- if (isAlwaysUniform(I))
- return;
-
- if (!usesValueFromCycle(I, OuterDivCycle))
- return;
-
- if (markDivergent(I))
- Worklist.push_back(&I);
-}
-
-// Mark all external users of values defined inside \param
-// OuterDivCycle as divergent.
+// Mark as divergent all external uses of values defined in \p DefCycle.
+//
+// A value V defined by a block B inside \p DefCycle may be used outside the
+// cycle only if the use is a PHI in some exit block, or B dominates some exit
+// block. Thus, we check uses as follows:
//
-// This follows all live out edges wherever they may lead. Potential
-// users of values defined inside DivCycle could be anywhere in the
-// dominance region of DivCycle (including its fringes for phi nodes).
-// A cycle C dominates a block B iff every path from the entry block
-// to B must pass through a block contained in C. If C is a reducible
-// cycle (or natural loop), C dominates B iff the header of C
-// dominates B. But in general, we iteratively examine cycle cycle
-// exits and their successors.
+// - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
+// - For every block B inside \p DefCycle that dominates at least one exit
+// block, check all uses outside \p DefCycle.
+//
+// FIXME: This function does not distinguish between divergent and uniform
+// exits. For each divergent exit, only the values that are live at that exit
+// need to be propagated as divergent at their use outside the cycle.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
- const CycleT &OuterDivCycle) {
- // Set of blocks that are dominated by the cycle, i.e., each is only
- // reachable from paths that pass through the cycle.
- SmallPtrSet<BlockT *, 16> DomRegion;
-
- // The boundary of DomRegion, formed by blocks that are not
- // dominated by the cycle.
- SmallVector<BlockT *> DomFrontier;
- OuterDivCycle.getExitBlocks(DomFrontier);
-
- // Returns true if BB is dominated by the cycle.
- auto isInDomRegion = [&](BlockT *BB) {
- for (auto *P : predecessors(BB)) {
- if (OuterDivCycle.contains(P))
- continue;
- if (DomRegion.count(P))
- continue;
- return false;
- }
- return true;
- };
-
- // Keep advancing the frontier along successor edges, while
- // promoting blocks to DomRegion.
- while (true) {
- bool Promoted = false;
- SmallVector<BlockT *> Temp;
- for (auto *W : DomFrontier) {
- if (!isInDomRegion(W)) {
- Temp.push_back(W);
- continue;
- }
- DomRegion.insert(W);
- Promoted = true;
- for (auto *Succ : successors(W)) {
- if (DomRegion.contains(Succ))
- continue;
- Temp.push_back(Succ);
+ const CycleT &DefCycle) {
+ SmallVector<BlockT *> Exits;
+ DefCycle.getExitBlocks(Exits);
+ for (auto *Exit : Exits) {
+ for (auto &Phi : Exit->phis()) {
+ if (usesValueFromCycle(Phi, DefCycle)) {
+ if (markDivergent(Phi))
+ Worklist.push_back(&Phi);
}
}
- if (!Promoted)
- break;
-
- // Restore the set property for the temporary vector
- llvm::sort(Temp);
- Temp.erase(std::unique(Temp.begin(), Temp.end()), Temp.end());
-
- DomFrontier = Temp;
}
- // At DomFrontier, only the PHI nodes are affected by temporal
- // divergence.
- for (const auto *UserBlock : DomFrontier) {
- LLVM_DEBUG(dbgs() << "Analyze phis after cycle exit: "
- << Context.print(UserBlock) << "\n");
- for (const auto &Phi : UserBlock->phis()) {
- LLVM_DEBUG(dbgs() << " " << Context.print(&Phi) << "\n");
- analyzeTemporalDivergence(Phi, OuterDivCycle);
- }
- }
-
- // All instructions inside the dominance region are affected by
- // temporal divergence.
- for (const auto *UserBlock : DomRegion) {
- LLVM_DEBUG(dbgs() << "Analyze non-phi users after cycle exit: "
- << Context.print(UserBlock) << "\n");
- for (const auto &I : *UserBlock) {
- LLVM_DEBUG(dbgs() << " " << Context.print(&I) << "\n");
- analyzeTemporalDivergence(I, OuterDivCycle);
+ for (auto *BB : DefCycle.blocks()) {
+ if (!llvm::any_of(Exits,
+ [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
+ continue;
+ for (auto &II : *BB) {
+ propagateTemporalDivergence(II, DefCycle);
}
}
}
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 13a9c2b7e4438..af766ef68bbba 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -78,6 +78,21 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
return false;
}
+template <>
+void llvm::GenericUniformityAnalysisImpl<
+ SSAContext>::propagateTemporalDivergence(const Instruction &I,
+ const Cycle &DefCycle) {
+ if (isDivergent(I))
+ return;
+ for (auto *User : I.users()) {
+ auto *UserInstr = cast<Instruction>(User);
+ if (DefCycle.contains(UserInstr->getParent()))
+ continue;
+ if (markDivergent(*UserInstr))
+ Worklist.push_back(UserInstr);
+ }
+}
+
template <>
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
const Use &U) const {
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index ef67bae1c1af0..22f38ae349b86 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -113,6 +113,28 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
return false;
}
+template <>
+void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
+ propagateTemporalDivergence(const MachineInstr &I,
+ const MachineCycle &DefCycle) {
+ const auto &RegInfo = F.getRegInfo();
+ for (auto &Op : I.operands()) {
+ if (!Op.isReg() || !Op.isDef())
+ continue;
+ if (!Op.getReg().isVirtual())
+ continue;
+ auto Reg = Op.getReg();
+ if (isDivergent(Reg))
+ continue;
+ for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
+ if (DefCycle.contains(UserInstr.getParent()))
+ continue;
+ if (markDivergent(UserInstr))
+ Worklist.push_back(&UserInstr);
+ }
+ }
+}
+
template <>
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
const MachineOperand &U) const {
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
index 842636aa952f0..6bab909276cb6 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/temporal_diverge.ll
@@ -14,13 +14,67 @@ entry:
H:
%uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
%uni.inc = add i32 %uni.merge.h, 1
+; CHECK: DIVERGENT: %div.exitx =
%div.exitx = icmp slt i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.exitx,
br i1 %div.exitx, label %X, label %H ; divergent branch
+
+X:
+; CHECK: DIVERGENT: %div.user =
+ %div.user = add i32 %uni.inc, 5
+ ret void
+}
+
+define amdgpu_kernel void @phi_at_exit(i32 %n, i32 %a, i32 %b) #0 {
+; CHECK-LABEL: for function 'phi_at_exit':
+; CHECK-NOT: DIVERGENT: %uni.
+; CHECK-NOT: DIVERGENT: br i1 %uni.
+
+entry:
+ %tid = call i32 @llvm.amdgcn.workitem.id.x()
+ %uni.cond = icmp slt i32 %a, 0
+ br i1 %uni.cond, label %H, label %X
+
+H:
+ %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+ %uni.inc = add i32 %uni.merge.h, 1
; CHECK: DIVERGENT: %div.exitx =
+ %div.exitx = icmp slt i32 %tid, 0
; CHECK: DIVERGENT: br i1 %div.exitx,
+ br i1 %div.exitx, label %X, label %H ; divergent branch
X:
- %div.user = add i32 %uni.inc, 5
+; CHECK: DIVERGENT: %div.phi =
+ %div.phi = phi i32 [ 0, %entry], [ %uni.inc, %H ]
+ %div.user = add i32 %div.phi, 5
+ ret void
+}
+
+define amdgpu_kernel void @phi_after_exit(i32 %n, i32 %a, i32 %b) #0 {
+; CHECK-LABEL: for function 'phi_after_exit':
+; CHECK-NOT: DIVERGENT: %uni.
+; CHECK-NOT: DIVERGENT: br i1 %uni.
+
+entry:
+ %tid = call i32 @llvm.amdgcn.workitem.id.x()
+ %uni.cond = icmp slt i32 %a, 0
+ br i1 %uni.cond, label %H, label %Y
+
+H:
+ %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+ %uni.inc = add i32 %uni.merge.h, 1
+; CHECK: DIVERGENT: %div.exitx =
+ %div.exitx = icmp slt i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.exitx,
+ br i1 %div.exitx, label %X, label %H ; divergent branch
+
+X:
+ br label %Y
+
+Y:
+; CHECK: DIVERGENT: %div.phi =
+ %div.phi = phi i32 [ 0, %entry], [ %uni.inc, %X ]
+ %div.user = add i32 %div.phi, 5
ret void
}
@@ -56,7 +110,8 @@ Y:
}
-; temporal-uniform use of a valud, definition and users are carried by a surrounding divergent loop
+; temporal-uniform use of a value, definition and users are carried by a
+; surrounding divergent loop
define amdgpu_kernel void @temporal_uniform_indivloop(i32 %n, i32 %a, i32 %b) #0 {
; CHECK-LABEL: for function 'temporal_uniform_indivloop':
; CHECK-NOT: DIVERGENT: %uni.
@@ -85,6 +140,7 @@ X:
Y:
%div.alsouser = add i32 %uni.inc, 5
ret void
+; CHECK: DIVERGENT: %div.alsouser =
}
@@ -113,6 +169,7 @@ X:
G:
%div.user = add i32 %uni.inc, 5
br i1 %uni.cond, label %G, label %Y
+; CHECK: DIVERGENT: %div.user =
Y:
ret void
@@ -127,10 +184,13 @@ define amdgpu_kernel void @temporal_diverge_loopuser_nested(i32 %n, i32 %a, i32
entry:
%tid = call i32 @llvm.amdgcn.workitem.id.x()
%uni.cond = icmp slt i32 %a, 0
+ br label %G
+
+G:
br label %H
H:
- %uni.merge.h = phi i32 [ 0, %entry ], [ %uni.inc, %H ]
+ %uni.merge.h = phi i32 [ 0, %G ], [ %uni.inc, %H ]
%uni.inc = add i32 %uni.merge.h, 1
%div.exitx = icmp slt i32 %tid, 0
br i1 %div.exitx, label %X, label %H ; divergent branch
@@ -138,17 +198,14 @@ H:
; CHECK: DIVERGENT: br i1 %div.exitx,
X:
- br label %G
-
-G:
+; CHECK: DIVERGENT: %div.user =
%div.user = add i32 %uni.inc, 5
- br i1 %uni.cond, label %G, label %Y
+ br i1 %uni.cond, label %X, label %G
Y:
ret void
}
-
declare i32 @llvm.amdgcn.workitem.id.x() #0
attributes #0 = { nounwind readnone }
More information about the llvm-commits mailing list