[llvm] d6b04f3 - [SDag] Refactor and simplify divergence calculation and checking. NFC.
Jay Foad via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 29 06:12:22 PDT 2020
Author: Jay Foad
Date: 2020-09-29T14:05:07+01:00
New Revision: d6b04f3937e374572039005d1446b4a950dc8f01
URL: https://github.com/llvm/llvm-project/commit/d6b04f3937e374572039005d1446b4a950dc8f01
DIFF: https://github.com/llvm/llvm-project/commit/d6b04f3937e374572039005d1446b4a950dc8f01.diff
LOG: [SDag] Refactor and simplify divergence calculation and checking. NFC.
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 6e733f8c9b9c..f86d46da23ce 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1424,6 +1424,9 @@ class SelectionDAG {
void setNodeMemRefs(MachineSDNode *N,
ArrayRef<MachineMemOperand *> NewMemRefs);
+ // Calculate divergence of node \p N based on its operands.
+ bool calculateDivergence(SDNode *N);
+
// Propagates the change in divergence to users
void updateDivergence(SDNode * N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3e3d79871162..cfb4aa2f0bb5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8718,21 +8718,31 @@ namespace {
} // end anonymous namespace
-void SelectionDAG::updateDivergence(SDNode * N)
-{
- if (TLI->isSDNodeAlwaysUniform(N))
- return;
- bool IsDivergent = TLI->isSDNodeSourceOfDivergence(N, FLI, DA);
+bool SelectionDAG::calculateDivergence(SDNode *N) {
+ if (TLI->isSDNodeAlwaysUniform(N)) {
+ assert(!TLI->isSDNodeSourceOfDivergence(N, FLI, DA) &&
+ "Conflicting divergence information!");
+ return false;
+ }
+ if (TLI->isSDNodeSourceOfDivergence(N, FLI, DA))
+ return true;
for (auto &Op : N->ops()) {
- if (Op.Val.getValueType() != MVT::Other)
- IsDivergent |= Op.getNode()->isDivergent();
+ if (Op.Val.getValueType() != MVT::Other && Op.getNode()->isDivergent())
+ return true;
}
- if (N->SDNodeBits.IsDivergent != IsDivergent) {
- N->SDNodeBits.IsDivergent = IsDivergent;
- for (auto U : N->uses()) {
- updateDivergence(U);
+ return false;
+}
+
+void SelectionDAG::updateDivergence(SDNode *N) {
+ SmallVector<SDNode *, 16> Worklist(1, N);
+ do {
+ N = Worklist.pop_back_val();
+ bool IsDivergent = calculateDivergence(N);
+ if (N->SDNodeBits.IsDivergent != IsDivergent) {
+ N->SDNodeBits.IsDivergent = IsDivergent;
+ Worklist.insert(Worklist.end(), N->use_begin(), N->use_end());
}
- }
+ } while (!Worklist.empty());
}
void SelectionDAG::CreateTopologicalOrder(std::vector<SDNode *> &Order) {
@@ -8758,26 +8768,9 @@ void SelectionDAG::CreateTopologicalOrder(std::vector<SDNode *> &Order) {
void SelectionDAG::VerifyDAGDiverence() {
std::vector<SDNode *> TopoOrder;
CreateTopologicalOrder(TopoOrder);
- const TargetLowering &TLI = getTargetLoweringInfo();
- DenseMap<const SDNode *, bool> DivergenceMap;
- for (auto &N : allnodes()) {
- DivergenceMap[&N] = false;
- }
- for (auto N : TopoOrder) {
- bool IsDivergent = DivergenceMap[N];
- bool IsSDNodeDivergent = TLI.isSDNodeSourceOfDivergence(N, FLI, DA);
- for (auto &Op : N->ops()) {
- if (Op.Val.getValueType() != MVT::Other)
- IsSDNodeDivergent |= DivergenceMap[Op.getNode()];
- }
- if (!IsDivergent && IsSDNodeDivergent && !TLI.isSDNodeAlwaysUniform(N)) {
- DivergenceMap[N] = true;
- }
- }
- for (auto &N : allnodes()) {
- (void)N;
- assert(DivergenceMap[&N] == N.isDivergent() &&
- "Divergence bit inconsistency detected\n");
+ for (auto *N : TopoOrder) {
+ assert(calculateDivergence(N) == N->isDivergent() &&
+ "Divergence bit inconsistency detected");
}
}
#endif
@@ -9963,13 +9956,14 @@ void SelectionDAG::createOperands(SDNode *Node, ArrayRef<SDValue> Vals) {
Ops[I].setUser(Node);
Ops[I].setInitial(Vals[I]);
if (Ops[I].Val.getValueType() != MVT::Other) // Skip Chain. It does not carry divergence.
- IsDivergent = IsDivergent || Ops[I].getNode()->isDivergent();
+ IsDivergent |= Ops[I].getNode()->isDivergent();
}
Node->NumOperands = Vals.size();
Node->OperandList = Ops;
- IsDivergent |= TLI->isSDNodeSourceOfDivergence(Node, FLI, DA);
- if (!TLI->isSDNodeAlwaysUniform(Node))
+ if (!TLI->isSDNodeAlwaysUniform(Node)) {
+ IsDivergent |= TLI->isSDNodeSourceOfDivergence(Node, FLI, DA);
Node->SDNodeBits.IsDivergent = IsDivergent;
+ }
checkForCycles(Node);
}
More information about the llvm-commits mailing list