[llvm] [SelectionDAG][X86] Preserve unpredictable metadata for conditional branches in SelectionDAG, as well as JCCs generated by X86 backend. (PR #102101)
Tianqing Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 7 22:53:26 PDT 2024
https://github.com/tianqingw updated https://github.com/llvm/llvm-project/pull/102101
>From da16934ccaed7ed54db72f62691c3ac6dbb7fdfd Mon Sep 17 00:00:00 2001
From: Tianqing Wang <tianqing.wang at intel.com>
Date: Tue, 6 Aug 2024 13:55:54 +0800
Subject: [PATCH] [SelectionDAG][X86] Preserve unpredictable metadata for
conditional branches in SelectionDAG, as well as JCCs generated by X86
backend.
This builds on 09515f2c2, which preserves unpredictable metadata in
CodeGen for `select`. This patch does the same for conditional branches.
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 5 ++++
.../llvm/CodeGen/SwitchLoweringUtils.h | 15 +++++++----
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 7 ++---
.../lib/CodeGen/SelectionDAG/InstrEmitter.cpp | 12 ++++-----
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 26 ++++++++++++++++---
.../SelectionDAG/SelectionDAGBuilder.cpp | 14 ++++++----
llvm/lib/Target/X86/X86ISelLowering.cpp | 22 ++++++++--------
llvm/test/CodeGen/X86/unpredictable-brcond.ll | 10 +++----
8 files changed, 72 insertions(+), 39 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1d0124ec755352..54561c2d1e3629 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1165,8 +1165,13 @@ class SelectionDAG {
SDValue N2, SDValue N3, const SDNodeFlags Flags);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+ SDValue N2, SDValue N3, SDValue N4, const SDNodeFlags Flags);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4, SDValue N5);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+ SDValue N2, SDValue N3, SDValue N4, SDValue N5,
+ const SDNodeFlags Flags);
// Specialize again based on number of operands for nodes with a VTList
// rather than a single VT.
diff --git a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
index 99478e9f39e226..9282c4a771afb2 100644
--- a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
+++ b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
@@ -137,18 +137,21 @@ struct CaseBlock {
SDLoc DL;
DebugLoc DbgLoc;
- // Branch weights.
+ // Branch weights and predictability.
BranchProbability TrueProb, FalseProb;
+ bool IsUnpredictable;
// Constructor for SelectionDAG.
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
const Value *cmpmiddle, MachineBasicBlock *truebb,
MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
BranchProbability trueprob = BranchProbability::getUnknown(),
- BranchProbability falseprob = BranchProbability::getUnknown())
+ BranchProbability falseprob = BranchProbability::getUnknown(),
+ bool isunpredictable = false)
: CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
- TrueProb(trueprob), FalseProb(falseprob) {}
+ TrueProb(trueprob), FalseProb(falseprob),
+ IsUnpredictable(isunpredictable) {}
// Constructor for GISel.
CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
@@ -156,10 +159,12 @@ struct CaseBlock {
MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
MachineBasicBlock *me, DebugLoc dl,
BranchProbability trueprob = BranchProbability::getUnknown(),
- BranchProbability falseprob = BranchProbability::getUnknown())
+ BranchProbability falseprob = BranchProbability::getUnknown(),
+ bool isunpredictable = false)
: PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
- DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
+ DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob),
+ IsUnpredictable(isunpredictable) {}
};
struct JumpTable {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a5b7397253e08d..9b96059ca674e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18164,7 +18164,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
// nondeterministic jumps).
if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
- N1->getOperand(0), N2);
+ N1->getOperand(0), N2, N->getFlags());
}
// Variant of the previous fold where there is a SETCC in between:
@@ -18213,7 +18213,8 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
if (Updated)
return DAG.getNode(
ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
- DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
+ DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2,
+ N->getFlags());
}
// If N is a constant we could fold this into a fallthrough or unconditional
@@ -18238,7 +18239,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
HandleSDNode ChainHandle(Chain);
if (SDValue NewN1 = rebuildSetCC(N1))
return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
- ChainHandle.getValue(), NewN1, N2);
+ ChainHandle.getValue(), NewN1, N2, N->getFlags());
}
return SDValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
index 4ce92e156cf85e..db33d5242601e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -1065,14 +1065,17 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
// Create the new machine instruction.
MachineInstrBuilder MIB = BuildMI(*MF, Node->getDebugLoc(), II);
+ // Transfer IR flags from the SDNode to the MachineInstr
+ MachineInstr *MI = MIB.getInstr();
+ const SDNodeFlags Flags = Node->getFlags();
+ if (Flags.hasUnpredictable())
+ MI->setFlag(MachineInstr::MIFlag::Unpredictable);
+
// Add result register values for things that are defined by this
// instruction.
if (NumResults) {
CreateVirtualRegisters(Node, MIB, II, IsClone, IsCloned, VRBaseMap);
- // Transfer any IR flags from the SDNode to the MachineInstr
- MachineInstr *MI = MIB.getInstr();
- const SDNodeFlags Flags = Node->getFlags();
if (Flags.hasNoSignedZeros())
MI->setFlag(MachineInstr::MIFlag::FmNsz);
@@ -1105,9 +1108,6 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
if (Flags.hasNoFPExcept())
MI->setFlag(MachineInstr::MIFlag::NoFPExcept);
-
- if (Flags.hasUnpredictable())
- MI->setFlag(MachineInstr::MIFlag::Unpredictable);
}
// Emit all of the actual operands of this instruction, adding them to the
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5c6a2454d664ba..0a35e89afac249 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7613,16 +7613,34 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
- SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
+ SDValue N1, SDValue N2, SDValue N3, SDValue N4,
+ const SDNodeFlags Flags) {
SDValue Ops[] = { N1, N2, N3, N4 };
- return getNode(Opcode, DL, VT, Ops);
+ return getNode(Opcode, DL, VT, Ops, Flags);
+}
+
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
+ SDNodeFlags Flags;
+ if (Inserter)
+ Flags = Inserter->getFlags();
+ return getNode(Opcode, DL, VT, N1, N2, N3, N4, Flags);
}
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, SDValue N2, SDValue N3, SDValue N4,
- SDValue N5) {
+ SDValue N5, const SDNodeFlags Flags) {
SDValue Ops[] = { N1, N2, N3, N4, N5 };
- return getNode(Opcode, DL, VT, Ops);
+ return getNode(Opcode, DL, VT, Ops, Flags);
+}
+
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ SDValue N1, SDValue N2, SDValue N3, SDValue N4,
+ SDValue N5) {
+ SDNodeFlags Flags;
+ if (Inserter)
+ Flags = Inserter->getFlags();
+ return getNode(Opcode, DL, VT, N1, N2, N3, N4, N5, Flags);
}
/// getStackArgumentTokenFactor - Compute a TokenFactor to force all
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..2dfa5c2e767a2f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2818,9 +2818,10 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {
// je foo
// cmp D, E
// jle foo
+ bool IsUnpredictable = I.hasMetadata(LLVMContext::MD_unpredictable);
const Instruction *BOp = dyn_cast<Instruction>(CondVal);
if (!DAG.getTargetLoweringInfo().isJumpExpensive() && BOp &&
- BOp->hasOneUse() && !I.hasMetadata(LLVMContext::MD_unpredictable)) {
+ BOp->hasOneUse() && !IsUnpredictable) {
Value *Vec;
const Value *BOp0, *BOp1;
Instruction::BinaryOps Opcode = (Instruction::BinaryOps)0;
@@ -2869,7 +2870,9 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {
// Create a CaseBlock record representing this branch.
CaseBlock CB(ISD::SETEQ, CondVal, ConstantInt::getTrue(*DAG.getContext()),
- nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc());
+ nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc(),
+ BranchProbability::getUnknown(), BranchProbability::getUnknown(),
+ IsUnpredictable);
// Use visitSwitchCase to actually insert the fast branch sequence for this
// cond branch.
@@ -2957,9 +2960,10 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
Cond = DAG.getNode(ISD::XOR, dl, Cond.getValueType(), Cond, True);
}
- SDValue BrCond = DAG.getNode(ISD::BRCOND, dl,
- MVT::Other, getControlRoot(), Cond,
- DAG.getBasicBlock(CB.TrueBB));
+ SDNodeFlags Flags;
+ Flags.setUnpredictable(CB.IsUnpredictable);
+ SDValue BrCond = DAG.getNode(ISD::BRCOND, dl, MVT::Other, getControlRoot(),
+ Cond, DAG.getBasicBlock(CB.TrueBB), Flags);
setValue(CurInst, BrCond);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2891e21be1b267..3fc632c37fc636 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24827,14 +24827,14 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- Overflow);
+ Overflow, Op->getFlags());
}
if (LHS.getSimpleValueType().isInteger()) {
SDValue CCVal;
SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, CC, SDLoc(Cond), DAG, CCVal);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- EFLAGS);
+ EFLAGS, Op->getFlags());
}
if (CC == ISD::SETOEQ) {
@@ -24860,10 +24860,10 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest,
- CCVal, Cmp);
+ CCVal, Cmp, Op->getFlags());
CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- Cmp);
+ Cmp, Op->getFlags());
}
}
} else if (CC == ISD::SETUNE) {
@@ -24872,18 +24872,18 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
// separate test.
SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
- Chain =
- DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, Cmp);
+ Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
+ Cmp, Op->getFlags());
CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- Cmp);
+ Cmp, Op->getFlags());
} else {
X86::CondCode X86Cond =
TranslateX86CC(CC, dl, /*IsFP*/ true, LHS, RHS, DAG);
SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- Cmp);
+ Cmp, Op->getFlags());
}
}
@@ -24894,7 +24894,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- Overflow);
+ Overflow, Op->getFlags());
}
// Look past the truncate if the high bits are known zero.
@@ -24913,8 +24913,8 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
SDValue CCVal;
SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, ISD::SETNE, dl, DAG, CCVal);
- return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
- EFLAGS);
+ return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, EFLAGS,
+ Op->getFlags());
}
// Lower dynamic stack allocation to _alloca call for Cygwin/Mingw targets.
diff --git a/llvm/test/CodeGen/X86/unpredictable-brcond.ll b/llvm/test/CodeGen/X86/unpredictable-brcond.ll
index 6c894ea8277679..12411f1c49f2db 100644
--- a/llvm/test/CodeGen/X86/unpredictable-brcond.ll
+++ b/llvm/test/CodeGen/X86/unpredictable-brcond.ll
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
-; Currently, unpredictable metadata on conditional branches is lost during CodeGen.
+; Make sure the MIR generated for a conditional branch with unpredictable metadata has the unpredictable flag.
; RUN: llc -mtriple=x86_64-unknown-linux-gnu -stop-after=finalize-isel < %s | FileCheck %s
define void @cond_branch_1(i1 %cond) {
@@ -11,7 +11,7 @@ define void @cond_branch_1(i1 %cond) {
; CHECK-NEXT: [[COPY:%[0-9]+]]:gr32 = COPY $edi
; CHECK-NEXT: [[COPY1:%[0-9]+]]:gr8 = COPY [[COPY]].sub_8bit
; CHECK-NEXT: TEST8ri killed [[COPY1]], 1, implicit-def $eflags
- ; CHECK-NEXT: JCC_1 %bb.2, 4, implicit $eflags
+ ; CHECK-NEXT: unpredictable JCC_1 %bb.2, 4, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.then:
@@ -51,7 +51,7 @@ define void @cond_branch_2(double %a, double %b, i32 %c, i32 %d) nounwind {
; CHECK-NEXT: [[SETCCr1:%[0-9]+]]:gr8 = SETCCr 6, implicit $eflags
; CHECK-NEXT: [[OR8rr:%[0-9]+]]:gr8 = OR8rr [[SETCCr]], killed [[SETCCr1]], implicit-def dead $eflags
; CHECK-NEXT: TEST8rr [[OR8rr]], [[OR8rr]], implicit-def $eflags
- ; CHECK-NEXT: JCC_1 %bb.2, 5, implicit $eflags
+ ; CHECK-NEXT: unpredictable JCC_1 %bb.2, 5, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.true:
@@ -89,8 +89,8 @@ define void @isint_branch(double %d) nounwind {
; CHECK-NEXT: [[CVTDQ2PDrr:%[0-9]+]]:vr128 = CVTDQ2PDrr killed [[CVTTPD2DQrr]]
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fr64 = COPY [[CVTDQ2PDrr]]
; CHECK-NEXT: nofpexcept UCOMISDrr [[COPY]], killed [[COPY2]], implicit-def $eflags, implicit $mxcsr
- ; CHECK-NEXT: JCC_1 %bb.2, 5, implicit $eflags
- ; CHECK-NEXT: JCC_1 %bb.2, 10, implicit $eflags
+ ; CHECK-NEXT: unpredictable JCC_1 %bb.2, 5, implicit $eflags
+ ; CHECK-NEXT: unpredictable JCC_1 %bb.2, 10, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.true:
More information about the llvm-commits
mailing list