[llvm] f8416c8 - [UniformityAnalysis] Replace DivergentValues with UniformValues for conservative divergence queries (#180509)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 27 00:27:22 PDT 2026
Author: Pankaj Dwivedi
Date: 2026-03-27T12:57:17+05:30
New Revision: f8416c8643534677ed89d0510bb5fd1b2a4e1abd
URL: https://github.com/llvm/llvm-project/commit/f8416c8643534677ed89d0510bb5fd1b2a4e1abd
DIFF: https://github.com/llvm/llvm-project/commit/f8416c8643534677ed89d0510bb5fd1b2a4e1abd.diff
LOG: [UniformityAnalysis] Replace DivergentValues with UniformValues for conservative divergence queries (#180509)
This patch replaces DivergentValues with UniformValues as the single
source of truth for tracking divergence in UniformityInfo.
Old model: DivergentValues starts empty; values are added as divergence
is propagated. isDivergent(V) returns DivergentValues.count(V).
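New model: UniformValues is populated during initialize() with the tracked
values (instructions/arguments for IR, all register defs for MIR). For IR,
values the target reports as NeverUniform are never inserted, so they are
divergent from the start. Further values are removed as divergence is
propagated.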
isDivergent(V) returns !UniformValues.contains(V), so any value not
present in the set (e.g., a newly created instruction that was not
present during analysis) is conservatively treated as divergent. This
avoids silent miscompilations when transformation passes introduce new
values and query their uniformity.
---------
Co-authored-by: padivedi <padivedi at amd.com>
Added:
llvm/unittests/Target/AMDGPU/UniformityAnalysisTest.cpp
Modified:
llvm/include/llvm/ADT/GenericSSAContext.h
llvm/include/llvm/ADT/GenericUniformityImpl.h
llvm/include/llvm/ADT/GenericUniformityInfo.h
llvm/lib/Analysis/UniformityAnalysis.cpp
llvm/lib/CodeGen/MachineSSAContext.cpp
llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
llvm/lib/IR/SSAContext.cpp
llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform.mir
llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/never-uniform.mir
llvm/unittests/Target/AMDGPU/CMakeLists.txt
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 426a083778d6e..36d0e2e8f70cd 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -94,6 +94,12 @@ template <typename _FunctionT> class GenericSSAContext {
const BlockT &block);
static bool isConstantOrUndefValuePhi(const InstructionT &Instr);
+
+ /// Whether \p V is always uniform and will not be added to UniformValues.
+ /// For IR this identifies constants and globals; for MIR it returns false
+ /// (all registers are tracked).
+ static bool isAlwaysUniform(ConstValueRefT V);
+
const BlockT *getDefBlock(ConstValueRefT value) const;
Printable print(const BlockT *block) const;
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index 6089f69b393bd..b6c714b704a57 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -363,8 +363,9 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
/// \brief Examine \p I for divergent outputs and add to the worklist.
void markDivergent(const InstructionT &I);
- /// \brief Mark \p DivVal as a divergent value.
- /// \returns Whether the tracked divergence state of \p DivVal changed.
+ /// \brief Mark \p DivVal as a divergent value by removing it from
+ /// UniformValues. \returns Whether the tracked divergence state of
+ /// \p DivVal changed.
bool markDivergent(ConstValueRefT DivVal);
/// \brief Mark outputs of \p Instr as divergent.
@@ -375,9 +376,6 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
/// Divergence is seeded by calls to \p markDivergent.
void compute();
- /// \brief Whether any value was marked or analyzed to be divergent.
- bool hasDivergence() const { return !DivergentValues.empty(); }
-
/// \brief Whether \p Val will always return a uniform value regardless of its
/// operands
bool isAlwaysUniform(const InstructionT &Instr) const;
@@ -392,7 +390,20 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
};
/// \brief Whether \p Val is divergent at its definition.
- bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }
+ /// When the target has no branch divergence, compute() is never called
+ /// and everything is uniform. Otherwise, values not in UniformValues
+ /// (e.g. newly created) are conservatively treated as divergent.
+ bool isDivergent(ConstValueRefT V) const {
+ if (!HasBranchDivergence)
+ return false;
+ // Only values that were present during analysis are tracked in
+ // UniformValues (Instructions/Arguments for IR, Registers for MIR).
+ // Other values (e.g. constants, globals) are always uniform but are
+ // not added to UniformValues; this check avoids false divergence.
+ if (ContextT::isAlwaysUniform(V))
+ return false;
+ return !UniformValues.contains(V);
+ }
bool isDivergentUse(const UseT &U) const;
@@ -402,6 +413,10 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
void print(raw_ostream &out) const;
+ /// Print divergent arguments and return true if any were found.
+ /// IR specialization iterates F.args(); default is a no-op.
+ bool printDivergentArgs(raw_ostream &out) const;
+
SmallVector<TemporalDivergenceTuple, 8> TemporalDivergenceList;
void recordTemporalDivergence(ConstValueRefT, const InstructionT *,
@@ -420,10 +435,18 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
const CycleInfoT &CI;
const TargetTransformInfo *TTI = nullptr;
- // Detected/marked divergent values.
- DenseSet<ConstValueRefT> DivergentValues;
+ // Whether the target has branch divergence. Set at the start of compute(),
+ // which is only called when the target has branch divergence. When false,
+ // isDivergent() returns false for all values.
+ bool HasBranchDivergence = false;
+
SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;
+ // Values known to be uniform. Populated in initialize() with all values,
+ // then values are removed as divergence is propagated. After analysis,
+ // values not in this set are conservatively treated as divergent.
+ DenseSet<ConstValueRefT> UniformValues;
+
// Internal worklist for divergence propagation.
std::vector<const InstructionT *> Worklist;
@@ -819,7 +842,7 @@ void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
ConstValueRefT Val) {
- if (DivergentValues.insert(Val).second) {
+ if (UniformValues.erase(Val)) {
LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
return true;
}
@@ -1128,12 +1151,7 @@ void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
- // Initialize worklist.
- auto DivValuesCopy = DivergentValues;
- for (const auto DivVal : DivValuesCopy) {
- assert(isDivergent(DivVal) && "Worklist invariant violated!");
- pushUsers(DivVal);
- }
+ HasBranchDivergence = true;
// All values on the Worklist are divergent.
// Their users may not have been updated yet.
@@ -1167,6 +1185,12 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
return UniformOverrides.contains(&Instr);
}
+template <typename ContextT>
+bool GenericUniformityAnalysisImpl<ContextT>::printDivergentArgs(
+ raw_ostream &) const {
+ return false;
+}
+
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
const DominatorTreeT &DT, const CycleInfoT &CI,
@@ -1176,35 +1200,18 @@ GenericUniformityInfo<ContextT>::GenericUniformityInfo(
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
- bool haveDivergentArgs = false;
-
// When we print Value, LLVM IR instruction, we want to print extra new line.
// In LLVM IR print function for Value does not print new line at the end.
// In MIR print for MachineInstr prints new line at the end.
constexpr bool IsMIR = std::is_same<InstructionT, MachineInstr>::value;
std::string NewLine = IsMIR ? "" : "\n";
- // Control flow instructions may be divergent even if their inputs are
- // uniform. Thus, although exceedingly rare, it is possible to have a program
- // with no divergent values but with divergent control structures.
- if (DivergentValues.empty() && DivergentTermBlocks.empty() &&
- DivergentExitCycles.empty()) {
- OS << "ALL VALUES UNIFORM\n";
- return;
- }
+ bool FoundDivergence = false;
- for (const auto &entry : DivergentValues) {
- const BlockT *parent = Context.getDefBlock(entry);
- if (!parent) {
- if (!haveDivergentArgs) {
- OS << "DIVERGENT ARGUMENTS:\n";
- haveDivergentArgs = true;
- }
- OS << " DIVERGENT: " << Context.print(entry) << '\n';
- }
- }
+ FoundDivergence |= printDivergentArgs(OS);
if (!AssumedDivergent.empty()) {
+ FoundDivergence = true;
OS << "CYCLES ASSUMED DIVERGENT:\n";
for (const CycleT *cycle : AssumedDivergent) {
OS << " " << cycle->print(Context) << '\n';
@@ -1212,6 +1219,7 @@ void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
}
if (!DivergentExitCycles.empty()) {
+ FoundDivergence = true;
OS << "CYCLES WITH DIVERGENT EXIT:\n";
for (const CycleT *cycle : DivergentExitCycles) {
OS << " " << cycle->print(Context) << '\n';
@@ -1219,6 +1227,7 @@ void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
}
if (!TemporalDivergenceList.empty()) {
+ FoundDivergence = true;
OS << "\nTEMPORAL DIVERGENCE LIST:\n";
for (auto [Val, UseInst, Cycle] : TemporalDivergenceList) {
@@ -1235,10 +1244,12 @@ void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
SmallVector<ConstValueRefT, 16> defs;
Context.appendBlockDefs(defs, block);
for (auto value : defs) {
- if (isDivergent(value))
+ if (isDivergent(value)) {
+ FoundDivergence = true;
OS << " DIVERGENT: ";
- else
+ } else {
OS << " ";
+ }
OS << Context.print(value) << NewLine;
}
@@ -1246,6 +1257,8 @@ void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
SmallVector<const InstructionT *, 8> terms;
Context.appendBlockTerms(terms, block);
bool divergentTerminators = hasDivergentTerminator(block);
+ if (divergentTerminators)
+ FoundDivergence = true;
for (auto *T : terms) {
if (divergentTerminators)
OS << " DIVERGENT: ";
@@ -1256,6 +1269,9 @@ void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
OS << "END BLOCK\n";
}
+
+ if (!FoundDivergence)
+ OS << "ALL VALUES UNIFORM\n";
}
template <typename ContextT>
@@ -1266,11 +1282,6 @@ GenericUniformityInfo<ContextT>::getTemporalDivergenceList() const {
DA->TemporalDivergenceList.end());
}
-template <typename ContextT>
-bool GenericUniformityInfo<ContextT>::hasDivergence() const {
- return DA->hasDivergence();
-}
-
template <typename ContextT>
const typename ContextT::FunctionT &
GenericUniformityInfo<ContextT>::getFunction() const {
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index be948707c3af8..a504335ec078e 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -54,9 +54,6 @@ template <typename ContextT> class GenericUniformityInfo {
DA->compute();
}
- /// Whether any divergence was detected.
- bool hasDivergence() const;
-
/// The GPU kernel this analysis result is for
const FunctionT &getFunction() const;
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 805f4894f3634..f40ea1a556f9f 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -30,30 +30,6 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
return markDivergent(cast<Value>(&Instr));
}
-template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
- for (auto &I : instructions(F)) {
- InstructionUniformity IU = TTI->getInstructionUniformity(&I);
- switch (IU) {
- case InstructionUniformity::AlwaysUniform:
- addUniformOverride(I);
- break;
- case InstructionUniformity::NeverUniform:
- markDivergent(I);
- break;
- case InstructionUniformity::Custom:
- addCustomUniformityCandidate(&I);
- break;
- case InstructionUniformity::Default:
- break;
- }
- }
- for (auto &Arg : F.args()) {
- if (TTI->getInstructionUniformity(&Arg) ==
- InstructionUniformity::NeverUniform)
- markDivergent(&Arg);
- }
-}
-
template <>
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
const Value *V) {
@@ -73,6 +49,65 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
pushUsers(cast<Value>(&Instr));
}
+template <>
+bool llvm::GenericUniformityAnalysisImpl<SSAContext>::printDivergentArgs(
+ raw_ostream &OS) const {
+ bool haveDivergentArgs = false;
+ for (const auto &Arg : F.args()) {
+ if (isDivergent(&Arg)) {
+ if (!haveDivergentArgs) {
+ OS << "DIVERGENT ARGUMENTS:\n";
+ haveDivergentArgs = true;
+ }
+ OS << " DIVERGENT: " << Context.print(&Arg) << '\n';
+ }
+ }
+ return haveDivergentArgs;
+}
+
+template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
+ // Pre-populate UniformValues with uniform values, then seed divergence.
+ // NeverUniform values are not inserted -- they are divergent by definition
+ // and will be reported as such by isDivergent() (not in UniformValues).
+ SmallVector<const Value *, 4> DivergentArgs;
+ for (auto &Arg : F.args()) {
+ if (TTI->getInstructionUniformity(&Arg) ==
+ InstructionUniformity::NeverUniform)
+ DivergentArgs.push_back(&Arg);
+ else
+ UniformValues.insert(&Arg);
+ }
+ for (auto &I : instructions(F)) {
+ InstructionUniformity IU = TTI->getInstructionUniformity(&I);
+ switch (IU) {
+ case InstructionUniformity::AlwaysUniform:
+ UniformValues.insert(&I);
+ addUniformOverride(I);
+ continue;
+ case InstructionUniformity::NeverUniform:
+ // Skip inserting -- divergent by definition. Add to Worklist directly
+ // so compute() propagates divergence to users.
+ if (I.isTerminator())
+ DivergentTermBlocks.insert(I.getParent());
+ Worklist.push_back(&I);
+ continue;
+ case InstructionUniformity::Custom:
+ UniformValues.insert(&I);
+ addCustomUniformityCandidate(&I);
+ continue;
+ case InstructionUniformity::Default:
+ UniformValues.insert(&I);
+ break;
+ }
+ }
+ // Arguments are not instructions and cannot go on the Worklist, so we
+ // propagate their divergence to users explicitly here. This must happen
+ // after all instructions are in UniformValues so markDivergent (called
+ // inside pushUsers) can successfully erase user instructions from the set.
+ for (const Value *Arg : DivergentArgs)
+ pushUsers(Arg);
+}
+
template <>
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
const Instruction &I, const Cycle &DefCycle) const {
diff --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp
index bbbfb3ce2788d..77f6771a62e66 100644
--- a/llvm/lib/CodeGen/MachineSSAContext.cpp
+++ b/llvm/lib/CodeGen/MachineSSAContext.cpp
@@ -59,6 +59,8 @@ static bool isUndef(const MachineInstr &MI) {
MI.getOpcode() == TargetOpcode::IMPLICIT_DEF;
}
+template <> bool MachineSSAContext::isAlwaysUniform(Register) { return false; }
+
/// MachineInstr equivalent of PHINode::hasConstantOrUndefValue() for G_PHI.
template <>
bool MachineSSAContext::isConstantOrUndefValuePhi(const MachineInstr &Phi) {
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index af1c448497a52..416864d07f7d0 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -48,6 +48,20 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
template <>
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
+ // Pre-populate UniformValues with all register defs. Physical register defs
+ // are included because they are never analyzed for divergence (initialize
+ // and markDefsDivergent skip them), so they must be in UniformValues to
+ // avoid being falsely reported as divergent.
+ for (const MachineBasicBlock &BB : F) {
+ for (const MachineInstr &MI : BB.instrs()) {
+ for (const MachineOperand &Op : MI.all_defs()) {
+ Register Reg = Op.getReg();
+ if (Reg)
+ UniformValues.insert(Reg);
+ }
+ }
+ }
+
const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
for (const MachineBasicBlock &block : F) {
diff --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp
index 20b6ea1e972d4..feb8570d32757 100644
--- a/llvm/lib/IR/SSAContext.cpp
+++ b/llvm/lib/IR/SSAContext.cpp
@@ -68,6 +68,10 @@ bool SSAContext::isConstantOrUndefValuePhi(const Instruction &Instr) {
return false;
}
+template <> bool SSAContext::isAlwaysUniform(const Value *V) {
+ return !isa<Instruction>(V) && !isa<Argument>(V);
+}
+
template <> Intrinsic::ID SSAContext::getIntrinsicID(const Instruction &I) {
if (auto *CB = dyn_cast<CallBase>(&I))
return CB->getIntrinsicID();
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
index 198dd9d2fe553..8f44855b3caff 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir
@@ -21,7 +21,7 @@ body: |
bb.1:
liveins: $sgpr4_sgpr5
; CHECK-LABEL: MachineUniformityInfo for function: @icmp
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%3:_(p4) = COPY $sgpr4_sgpr5
%13:_(s32) = G_CONSTANT i32 0
@@ -41,7 +41,7 @@ body: |
bb.1:
liveins: $sgpr4_sgpr5
; CHECK-LABEL: MachineUniformityInfo for function: @fcmp
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%3:_(p4) = COPY $sgpr4_sgpr5
%10:_(s32) = G_CONSTANT i32 0
@@ -64,7 +64,7 @@ body: |
bb.1:
liveins: $sgpr4_sgpr5
; CHECK-LABEL: MachineUniformityInfo for function: @ballot
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%2:_(p4) = COPY $sgpr4_sgpr5
%10:_(p1) = G_IMPLICIT_DEF
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform.mir b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform.mir
index f5161bbddc795..09d3cfb9aa6cd 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform.mir
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform.mir
@@ -10,7 +10,7 @@ machineFunctionInfo:
body: |
bb.0:
; CHECK-LABEL: MachineUniformityInfo for function: @readlane
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%0:vgpr_32 = IMPLICIT_DEF
%1:vgpr_32 = IMPLICIT_DEF
%2:sreg_32_xm0 = V_READFIRSTLANE_B32 %0, implicit $exec
@@ -28,7 +28,7 @@ machineFunctionInfo:
body: |
bb.0:
; CHECK-LABEL: MachineUniformityInfo for function: @readlane2
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%0:vgpr_32 = IMPLICIT_DEF
%1:vgpr_32 = IMPLICIT_DEF
%4:sgpr_32 = V_READLANE_B32 $vgpr0, 0, implicit $exec
@@ -49,7 +49,7 @@ machineFunctionInfo:
body: |
bb.0:
; CHECK-LABEL: MachineUniformityInfo for function: @sgprcopy
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
liveins: $sgpr0,$sgpr1,$vgpr0
%0:sgpr_32 = COPY $sgpr0
%1:vgpr_32 = COPY $sgpr1
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/never-uniform.mir b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/never-uniform.mir
index 3ef1c46dd1b2d..96342373a3637 100644
--- a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/never-uniform.mir
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/never-uniform.mir
@@ -65,7 +65,7 @@ machineFunctionInfo:
body: |
bb.0:
; CHECK-LABEL: MachineUniformityInfo for function: @dsreads
- ; CHECK-NEXT: ALL VALUES UNIFORM
+ ; CHECK: ALL VALUES UNIFORM
%0:vreg_64 = IMPLICIT_DEF
$m0 = S_MOV_B32 0
%1:vgpr_32 = DS_READ_ADDTID_B32 0, 0, implicit $m0, implicit $exec
diff --git a/llvm/unittests/Target/AMDGPU/CMakeLists.txt b/llvm/unittests/Target/AMDGPU/CMakeLists.txt
index d6cbaf3f3fb5d..7760f694933c2 100644
--- a/llvm/unittests/Target/AMDGPU/CMakeLists.txt
+++ b/llvm/unittests/Target/AMDGPU/CMakeLists.txt
@@ -8,6 +8,8 @@ set(LLVM_LINK_COMPONENTS
AMDGPUDesc
AMDGPUInfo
AMDGPUUtils
+ Analysis
+ AsmParser
CodeGen
CodeGenTypes
Core
@@ -25,4 +27,5 @@ add_llvm_target_unittest(AMDGPUTests
ExecMayBeModifiedBeforeAnyUse.cpp
LiveRegUnits.cpp
PALMetadata.cpp
+ UniformityAnalysisTest.cpp
)
diff --git a/llvm/unittests/Target/AMDGPU/UniformityAnalysisTest.cpp b/llvm/unittests/Target/AMDGPU/UniformityAnalysisTest.cpp
new file mode 100644
index 0000000000000..ae44d3ef6cbf2
--- /dev/null
+++ b/llvm/unittests/Target/AMDGPU/UniformityAnalysisTest.cpp
@@ -0,0 +1,95 @@
+//===- UniformityAnalysisTest.cpp - Conservative divergence query test ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests that values created after Uniformity analysis are conservatively
+// reported as divergent, since they are not present in UniformValues.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/UniformityAnalysis.h"
+#include "llvm/ADT/GenericUniformityImpl.h"
+#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Triple.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<TargetMachine>
+createAMDGPUTargetMachine(std::string TStr, StringRef CPU, StringRef FS) {
+ Triple TT(TStr);
+ std::string Error;
+ const Target *T = TargetRegistry::lookupTarget(TT, Error);
+ if (!T)
+ return nullptr;
+ return std::unique_ptr<TargetMachine>(
+ T->createTargetMachine(TT, CPU, FS, {}, std::nullopt));
+}
+
+static UniformityInfo computeUniformity(const TargetTransformInfo *TTI,
+ Function *F) {
+ DominatorTree DT(*F);
+ CycleInfo CI;
+ CI.compute(*F);
+ UniformityInfo UI(DT, CI, TTI);
+ if (TTI->hasBranchDivergence(F))
+ UI.compute();
+ return UI;
+}
+
+TEST(UniformityAnalysis, NewValueIsConservativelyDivergent) {
+ LLVMInitializeAMDGPUTargetInfo();
+ LLVMInitializeAMDGPUTarget();
+ LLVMInitializeAMDGPUTargetMC();
+
+ StringRef ModuleString = R"(
+ target triple = "amdgcn-unknown-amdhsa"
+ define amdgpu_kernel void @test(i32 inreg %a, i32 inreg %b) {
+ %add = add i32 %a, %b
+ ret void
+ }
+ )";
+ LLVMContext Context;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(ModuleString, Err, Context);
+ ASSERT_TRUE(M) << Err.getMessage();
+
+ Function *F = M->getFunction("test");
+ ASSERT_TRUE(F);
+
+ auto TM =
+ createAMDGPUTargetMachine("amdgcn-amd-", "gfx1010", "+wavefrontsize32");
+ ASSERT_TRUE(TM);
+ TargetTransformInfo TTI = TM->getTargetTransformInfo(*F);
+
+ UniformityInfo UI = computeUniformity(&TTI, F);
+
+ // Existing values from the analysis are uniform (kernel args are inreg).
+ Instruction *AddInst = &*F->getEntryBlock().begin();
+ ASSERT_TRUE(isa<BinaryOperator>(AddInst));
+ EXPECT_FALSE(UI.isDivergent(AddInst)) << "%add should be uniform";
+ EXPECT_FALSE(UI.isDivergent(F->getArg(0))) << "%a should be uniform";
+ EXPECT_FALSE(UI.isDivergent(F->getArg(1))) << "%b should be uniform";
+
+ // Create a new instruction after analysis. It was not present during
+ // analysis, so it is not in UniformValues and must be conservatively
+ // reported as divergent.
+ IRBuilder<> Builder(AddInst->getNextNode());
+ Value *NewInst = Builder.CreateMul(F->getArg(0), F->getArg(1), "new_mul");
+
+ EXPECT_TRUE(UI.isDivergent(NewInst))
+ << "New instruction created after analysis must be reported divergent";
+}
More information about the llvm-commits
mailing list