[llvm] [SelectionDAG][X86] Preserve unpredictable metadata for conditional branches in SelectionDAG, as well as for JCCs generated by the X86 backend. (PR #102101)

Tianqing Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 7 22:53:26 PDT 2024


https://github.com/tianqingw updated https://github.com/llvm/llvm-project/pull/102101

>From da16934ccaed7ed54db72f62691c3ac6dbb7fdfd Mon Sep 17 00:00:00 2001
From: Tianqing Wang <tianqing.wang at intel.com>
Date: Tue, 6 Aug 2024 13:55:54 +0800
Subject: [PATCH] [SelectionDAG][X86] Preserve unpredictable metadata for
 conditional branches in SelectionDAG, as well as for JCCs generated by the
 X86 backend.

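This builds on 09515f2c2, which preserves unpredictable metadata in
CodeGen for `select`. This patch does the same for conditional branches:
the unpredictable flag is carried on BRCOND nodes through DAG combines and
X86 lowering, and is set on the resulting JCC machine instructions.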
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++++
 .../llvm/CodeGen/SwitchLoweringUtils.h        | 15 +++++++----
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  7 ++---
 .../lib/CodeGen/SelectionDAG/InstrEmitter.cpp | 12 ++++-----
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 26 ++++++++++++++++---
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 14 ++++++----
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 22 ++++++++--------
 llvm/test/CodeGen/X86/unpredictable-brcond.ll | 10 +++----
 8 files changed, 72 insertions(+), 39 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1d0124ec755352..54561c2d1e3629 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1165,8 +1165,13 @@ class SelectionDAG {
                   SDValue N2, SDValue N3, const SDNodeFlags Flags);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                   SDValue N2, SDValue N3, SDValue N4);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2, SDValue N3, SDValue N4, const SDNodeFlags Flags);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                   SDValue N2, SDValue N3, SDValue N4, SDValue N5);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2, SDValue N3, SDValue N4, SDValue N5,
+                  const SDNodeFlags Flags);
 
   // Specialize again based on number of operands for nodes with a VTList
   // rather than a single VT.
diff --git a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
index 99478e9f39e226..9282c4a771afb2 100644
--- a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
+++ b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
@@ -137,18 +137,21 @@ struct CaseBlock {
   SDLoc DL;
   DebugLoc DbgLoc;
 
-  // Branch weights.
+  // Branch weights and predictability.
   BranchProbability TrueProb, FalseProb;
+  bool IsUnpredictable;
 
   // Constructor for SelectionDAG.
   CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
             const Value *cmpmiddle, MachineBasicBlock *truebb,
             MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
             BranchProbability trueprob = BranchProbability::getUnknown(),
-            BranchProbability falseprob = BranchProbability::getUnknown())
+            BranchProbability falseprob = BranchProbability::getUnknown(),
+            bool isunpredictable = false)
       : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
         TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
-        TrueProb(trueprob), FalseProb(falseprob) {}
+        TrueProb(trueprob), FalseProb(falseprob),
+        IsUnpredictable(isunpredictable) {}
 
   // Constructor for GISel.
   CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
@@ -156,10 +159,12 @@ struct CaseBlock {
             MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
             MachineBasicBlock *me, DebugLoc dl,
             BranchProbability trueprob = BranchProbability::getUnknown(),
-            BranchProbability falseprob = BranchProbability::getUnknown())
+            BranchProbability falseprob = BranchProbability::getUnknown(),
+            bool isunpredictable = false)
       : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
         CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
-        DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
+        DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob),
+        IsUnpredictable(isunpredictable) {}
 };
 
 struct JumpTable {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a5b7397253e08d..9b96059ca674e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18164,7 +18164,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
   // nondeterministic jumps).
   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
-                       N1->getOperand(0), N2);
+                       N1->getOperand(0), N2, N->getFlags());
   }
 
   // Variant of the previous fold where there is a SETCC in between:
@@ -18213,7 +18213,8 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
     if (Updated)
       return DAG.getNode(
           ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
-          DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
+          DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2,
+          N->getFlags());
   }
 
   // If N is a constant we could fold this into a fallthrough or unconditional
@@ -18238,7 +18239,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
     HandleSDNode ChainHandle(Chain);
     if (SDValue NewN1 = rebuildSetCC(N1))
       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
-                         ChainHandle.getValue(), NewN1, N2);
+                         ChainHandle.getValue(), NewN1, N2, N->getFlags());
   }
 
   return SDValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
index 4ce92e156cf85e..db33d5242601e0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -1065,14 +1065,17 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
   // Create the new machine instruction.
   MachineInstrBuilder MIB = BuildMI(*MF, Node->getDebugLoc(), II);
 
+  // Transfer the unpredictable flag even for nodes with no results (branches).
+  MachineInstr *MI = MIB.getInstr();
+  const SDNodeFlags Flags = Node->getFlags();
+  if (Flags.hasUnpredictable())
+    MI->setFlag(MachineInstr::MIFlag::Unpredictable);
+
   // Add result register values for things that are defined by this
   // instruction.
   if (NumResults) {
     CreateVirtualRegisters(Node, MIB, II, IsClone, IsCloned, VRBaseMap);
 
-    // Transfer any IR flags from the SDNode to the MachineInstr
-    MachineInstr *MI = MIB.getInstr();
-    const SDNodeFlags Flags = Node->getFlags();
     if (Flags.hasNoSignedZeros())
       MI->setFlag(MachineInstr::MIFlag::FmNsz);
 
@@ -1105,9 +1108,6 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
 
     if (Flags.hasNoFPExcept())
       MI->setFlag(MachineInstr::MIFlag::NoFPExcept);
-
-    if (Flags.hasUnpredictable())
-      MI->setFlag(MachineInstr::MIFlag::Unpredictable);
   }
 
   // Emit all of the actual operands of this instruction, adding them to the
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5c6a2454d664ba..0a35e89afac249 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7613,16 +7613,34 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
-                              SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
+                              SDValue N1, SDValue N2, SDValue N3, SDValue N4,
+                              const SDNodeFlags Flags) {
   SDValue Ops[] = { N1, N2, N3, N4 };
-  return getNode(Opcode, DL, VT, Ops);
+  return getNode(Opcode, DL, VT, Ops, Flags);
+}
+
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
+  SDNodeFlags Flags;
+  if (Inserter)
+    Flags = Inserter->getFlags();
+  return getNode(Opcode, DL, VT, N1, N2, N3, N4, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, SDValue N2, SDValue N3, SDValue N4,
-                              SDValue N5) {
+                              SDValue N5, const SDNodeFlags Flags) {
   SDValue Ops[] = { N1, N2, N3, N4, N5 };
-  return getNode(Opcode, DL, VT, Ops);
+  return getNode(Opcode, DL, VT, Ops, Flags);
+}
+
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              SDValue N1, SDValue N2, SDValue N3, SDValue N4,
+                              SDValue N5) {
+  SDNodeFlags Flags;
+  if (Inserter)
+    Flags = Inserter->getFlags();
+  return getNode(Opcode, DL, VT, N1, N2, N3, N4, N5, Flags);
 }
 
 /// getStackArgumentTokenFactor - Compute a TokenFactor to force all
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..2dfa5c2e767a2f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2818,9 +2818,10 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {
   //     je foo
   //     cmp D, E
   //     jle foo
+  bool IsUnpredictable = I.hasMetadata(LLVMContext::MD_unpredictable);
   const Instruction *BOp = dyn_cast<Instruction>(CondVal);
   if (!DAG.getTargetLoweringInfo().isJumpExpensive() && BOp &&
-      BOp->hasOneUse() && !I.hasMetadata(LLVMContext::MD_unpredictable)) {
+      BOp->hasOneUse() && !IsUnpredictable) {
     Value *Vec;
     const Value *BOp0, *BOp1;
     Instruction::BinaryOps Opcode = (Instruction::BinaryOps)0;
@@ -2869,7 +2870,9 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {
 
   // Create a CaseBlock record representing this branch.
   CaseBlock CB(ISD::SETEQ, CondVal, ConstantInt::getTrue(*DAG.getContext()),
-               nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc());
+               nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc(),
+               BranchProbability::getUnknown(), BranchProbability::getUnknown(),
+               IsUnpredictable);
 
   // Use visitSwitchCase to actually insert the fast branch sequence for this
   // cond branch.
@@ -2957,9 +2960,10 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
     Cond = DAG.getNode(ISD::XOR, dl, Cond.getValueType(), Cond, True);
   }
 
-  SDValue BrCond = DAG.getNode(ISD::BRCOND, dl,
-                               MVT::Other, getControlRoot(), Cond,
-                               DAG.getBasicBlock(CB.TrueBB));
+  SDNodeFlags Flags;
+  Flags.setUnpredictable(CB.IsUnpredictable);
+  SDValue BrCond = DAG.getNode(ISD::BRCOND, dl, MVT::Other, getControlRoot(),
+                               Cond, DAG.getBasicBlock(CB.TrueBB), Flags);
 
   setValue(CurInst, BrCond);
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2891e21be1b267..3fc632c37fc636 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24827,14 +24827,14 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
 
       SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                         Overflow);
+                         Overflow, Op->getFlags());
     }
 
     if (LHS.getSimpleValueType().isInteger()) {
       SDValue CCVal;
       SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, CC, SDLoc(Cond), DAG, CCVal);
       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                         EFLAGS);
+                         EFLAGS, Op->getFlags());
     }
 
     if (CC == ISD::SETOEQ) {
@@ -24860,10 +24860,10 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
               DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
           SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
           Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest,
-                              CCVal, Cmp);
+                              CCVal, Cmp, Op->getFlags());
           CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
           return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                             Cmp);
+                             Cmp, Op->getFlags());
         }
       }
     } else if (CC == ISD::SETUNE) {
@@ -24872,18 +24872,18 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
       // separate test.
       SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
       SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
-      Chain =
-          DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, Cmp);
+      Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
+                          Cmp, Op->getFlags());
       CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                         Cmp);
+                         Cmp, Op->getFlags());
     } else {
       X86::CondCode X86Cond =
           TranslateX86CC(CC, dl, /*IsFP*/ true, LHS, RHS, DAG);
       SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
       SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                         Cmp);
+                         Cmp, Op->getFlags());
     }
   }
 
@@ -24894,7 +24894,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
 
     SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
     return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                       Overflow);
+                       Overflow, Op->getFlags());
   }
 
   // Look past the truncate if the high bits are known zero.
@@ -24913,8 +24913,8 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
 
   SDValue CCVal;
   SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, ISD::SETNE, dl, DAG, CCVal);
-  return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
-                     EFLAGS);
+  return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, EFLAGS,
+                     Op->getFlags());
 }
 
 // Lower dynamic stack allocation to _alloca call for Cygwin/Mingw targets.
diff --git a/llvm/test/CodeGen/X86/unpredictable-brcond.ll b/llvm/test/CodeGen/X86/unpredictable-brcond.ll
index 6c894ea8277679..12411f1c49f2db 100644
--- a/llvm/test/CodeGen/X86/unpredictable-brcond.ll
+++ b/llvm/test/CodeGen/X86/unpredictable-brcond.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
-; Currently, unpredictable metadata on conditional branches is lost during CodeGen.
+; Make sure the MIR generated for conditional branches with unpredictable metadata carries the unpredictable flag.
 ; RUN: llc -mtriple=x86_64-unknown-linux-gnu -stop-after=finalize-isel < %s | FileCheck %s
 
 define void @cond_branch_1(i1 %cond) {
@@ -11,7 +11,7 @@ define void @cond_branch_1(i1 %cond) {
   ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gr32 = COPY $edi
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gr8 = COPY [[COPY]].sub_8bit
   ; CHECK-NEXT:   TEST8ri killed [[COPY1]], 1, implicit-def $eflags
-  ; CHECK-NEXT:   JCC_1 %bb.2, 4, implicit $eflags
+  ; CHECK-NEXT:   unpredictable JCC_1 %bb.2, 4, implicit $eflags
   ; CHECK-NEXT:   JMP_1 %bb.1
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT: bb.1.then:
@@ -51,7 +51,7 @@ define void @cond_branch_2(double %a, double %b, i32 %c, i32 %d) nounwind {
   ; CHECK-NEXT:   [[SETCCr1:%[0-9]+]]:gr8 = SETCCr 6, implicit $eflags
   ; CHECK-NEXT:   [[OR8rr:%[0-9]+]]:gr8 = OR8rr [[SETCCr]], killed [[SETCCr1]], implicit-def dead $eflags
   ; CHECK-NEXT:   TEST8rr [[OR8rr]], [[OR8rr]], implicit-def $eflags
-  ; CHECK-NEXT:   JCC_1 %bb.2, 5, implicit $eflags
+  ; CHECK-NEXT:   unpredictable JCC_1 %bb.2, 5, implicit $eflags
   ; CHECK-NEXT:   JMP_1 %bb.1
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT: bb.1.true:
@@ -89,8 +89,8 @@ define void @isint_branch(double %d) nounwind {
   ; CHECK-NEXT:   [[CVTDQ2PDrr:%[0-9]+]]:vr128 = CVTDQ2PDrr killed [[CVTTPD2DQrr]]
   ; CHECK-NEXT:   [[COPY2:%[0-9]+]]:fr64 = COPY [[CVTDQ2PDrr]]
   ; CHECK-NEXT:   nofpexcept UCOMISDrr [[COPY]], killed [[COPY2]], implicit-def $eflags, implicit $mxcsr
-  ; CHECK-NEXT:   JCC_1 %bb.2, 5, implicit $eflags
-  ; CHECK-NEXT:   JCC_1 %bb.2, 10, implicit $eflags
+  ; CHECK-NEXT:   unpredictable JCC_1 %bb.2, 5, implicit $eflags
+  ; CHECK-NEXT:   unpredictable JCC_1 %bb.2, 10, implicit $eflags
   ; CHECK-NEXT:   JMP_1 %bb.1
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT: bb.1.true:



More information about the llvm-commits mailing list