[llvm] r337446 - [x86/SLH] Major refactoring of SLH implementation. There are two big

Chandler Carruth via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 19 04:13:58 PDT 2018


Author: chandlerc
Date: Thu Jul 19 04:13:58 2018
New Revision: 337446

URL: http://llvm.org/viewvc/llvm-project?rev=337446&view=rev
Log:
[x86/SLH] Major refactoring of SLH implementation. There are two big
changes that are intertwined here:

1) Extracting the tracing of predicate state through the CFG to its own
   function.
2) Creating a struct to manage the predicate state used throughout the
   pass.

Doing #1 necessitates and motivates the particular approach for #2 as
now the predicate management is spread across different functions
focused on different aspects of it. A number of simplifications then
fell out as a direct consequence.

I went with an Optional to make it more natural to construct the
MachineSSAUpdater object.

This is probably the single largest outstanding refactoring step I have.
Things get a bit more surgical from here. My current goal, beyond
generally making this maintainable long-term, is to implement several
improvements to how we do interprocedural tracking of predicate state.
But I don't want to do that until the predicate state management and
tracing is in a reasonably clear state.

Differential Revision: https://reviews.llvm.org/D49427

Modified:
    llvm/trunk/lib/Target/X86/X86SpeculativeLoadHardening.cpp

Modified: llvm/trunk/lib/Target/X86/X86SpeculativeLoadHardening.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86SpeculativeLoadHardening.cpp?rev=337446&r1=337445&r2=337446&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86SpeculativeLoadHardening.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86SpeculativeLoadHardening.cpp Thu Jul 19 04:13:58 2018
@@ -26,6 +26,7 @@
 #include "X86Subtarget.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -141,17 +142,33 @@ private:
     MachineInstr *UncondBr;
   };
 
+  /// Manages the predicate state traced through the program.
+  struct PredState {
+    unsigned InitialReg;
+    unsigned PoisonReg;
+
+    const TargetRegisterClass *RC;
+    MachineSSAUpdater SSA;
+
+    PredState(MachineFunction &MF, const TargetRegisterClass *RC)
+        : RC(RC), SSA(MF) {}
+  };
+
   const X86Subtarget *Subtarget;
   MachineRegisterInfo *MRI;
   const X86InstrInfo *TII;
   const TargetRegisterInfo *TRI;
-  const TargetRegisterClass *PredStateRC;
+
+  Optional<PredState> PS;
 
   void hardenEdgesWithLFENCE(MachineFunction &MF);
 
   SmallVector<BlockCondInfo, 16> collectBlockCondInfo(MachineFunction &MF);
 
-  void checkAllLoads(MachineFunction &MF, MachineSSAUpdater &PredStateSSA);
+  SmallVector<MachineInstr *, 16>
+  tracePredStateThroughCFG(MachineFunction &MF, ArrayRef<BlockCondInfo> Infos);
+
+  void checkAllLoads(MachineFunction &MF);
 
   unsigned saveEFLAGS(MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator InsertPt, DebugLoc Loc);
@@ -168,15 +185,15 @@ private:
 
   void
   hardenLoadAddr(MachineInstr &MI, MachineOperand &BaseMO,
-                 MachineOperand &IndexMO, MachineSSAUpdater &PredStateSSA,
+                 MachineOperand &IndexMO,
                  SmallDenseMap<unsigned, unsigned, 32> &AddrRegToHardenedReg);
   MachineInstr *
   sinkPostLoadHardenedInst(MachineInstr &MI,
                            SmallPtrSetImpl<MachineInstr *> &HardenedInstrs);
   bool canHardenRegister(unsigned Reg);
-  void hardenPostLoad(MachineInstr &MI, MachineSSAUpdater &PredStateSSA);
-  void checkReturnInstr(MachineInstr &MI, MachineSSAUpdater &PredStateSSA);
-  void checkCallInstr(MachineInstr &MI, MachineSSAUpdater &PredStateSSA);
+  void hardenPostLoad(MachineInstr &MI);
+  void checkReturnInstr(MachineInstr &MI);
+  void checkCallInstr(MachineInstr &MI);
 };
 
 } // end anonymous namespace
@@ -371,8 +388,9 @@ bool X86SpeculativeLoadHardeningPass::ru
   MRI = &MF.getRegInfo();
   TII = Subtarget->getInstrInfo();
   TRI = Subtarget->getRegisterInfo();
+
   // FIXME: Support for 32-bit.
-  PredStateRC = &X86::GR64_NOSPRegClass;
+  PS.emplace(MF, &X86::GR64_NOSPRegClass);
 
   if (MF.begin() == MF.end())
     // Nothing to do for a degenerate empty function...
@@ -401,14 +419,11 @@ bool X86SpeculativeLoadHardeningPass::ru
   if (!HasVulnerableLoad && Infos.empty())
     return true;
 
-  unsigned PredStateReg;
-  unsigned PredStateSizeInBytes = TRI->getRegSizeInBits(*PredStateRC) / 8;
-
   // The poison value is required to be an all-ones value for many aspects of
   // this mitigation.
   const int PoisonVal = -1;
-  unsigned PoisonReg = MRI->createVirtualRegister(PredStateRC);
-  BuildMI(Entry, EntryInsertPt, Loc, TII->get(X86::MOV64ri32), PoisonReg)
+  PS->PoisonReg = MRI->createVirtualRegister(PS->RC);
+  BuildMI(Entry, EntryInsertPt, Loc, TII->get(X86::MOV64ri32), PS->PoisonReg)
       .addImm(PoisonVal);
   ++NumInstsInserted;
 
@@ -436,11 +451,11 @@ bool X86SpeculativeLoadHardeningPass::ru
   if (HardenInterprocedurally && !FenceCallAndRet) {
     // Set up the predicate state by extracting it from the incoming stack
     // pointer so we pick up any misspeculation in our caller.
-    PredStateReg = extractPredStateFromSP(Entry, EntryInsertPt, Loc);
+    PS->InitialReg = extractPredStateFromSP(Entry, EntryInsertPt, Loc);
   } else {
     // Otherwise, just build the predicate state itself by zeroing a register
     // as we don't need any initial state.
-    PredStateReg = MRI->createVirtualRegister(PredStateRC);
+    PS->InitialReg = MRI->createVirtualRegister(PS->RC);
     unsigned PredStateSubReg = MRI->createVirtualRegister(&X86::GR32RegClass);
     auto ZeroI = BuildMI(Entry, EntryInsertPt, Loc, TII->get(X86::MOV32r0),
                          PredStateSubReg);
@@ -451,7 +466,7 @@ bool X86SpeculativeLoadHardeningPass::ru
            "Must have an implicit def of EFLAGS!");
     ZeroEFLAGSDefOp->setIsDead(true);
     BuildMI(Entry, EntryInsertPt, Loc, TII->get(X86::SUBREG_TO_REG),
-            PredStateReg)
+            PS->InitialReg)
         .addImm(0)
         .addReg(PredStateSubReg)
         .addImm(X86::sub_32bit);
@@ -464,142 +479,11 @@ bool X86SpeculativeLoadHardeningPass::ru
 
   // Track the updated values in an SSA updater to rewrite into SSA form at the
   // end.
-  MachineSSAUpdater PredStateSSA(MF);
-  PredStateSSA.Initialize(PredStateReg);
-  PredStateSSA.AddAvailableValue(&Entry, PredStateReg);
-  // Collect the inserted instructions so we can rewrite their uses of the
-  // predicate state into SSA form.
-  SmallVector<MachineInstr *, 16> CMovs;
-
-  // Now walk all of the basic blocks looking for ones that end in conditional
-  // jumps where we need to update this register along each edge.
-  for (BlockCondInfo &Info : Infos) {
-    MachineBasicBlock &MBB = *Info.MBB;
-    SmallVectorImpl<MachineInstr *> &CondBrs = Info.CondBrs;
-    MachineInstr *UncondBr = Info.UncondBr;
-
-    LLVM_DEBUG(dbgs() << "Tracing predicate through block: " << MBB.getName()
-                      << "\n");
-    ++NumCondBranchesTraced;
-
-    // Compute the non-conditional successor as either the target of any
-    // unconditional branch or the layout successor.
-    MachineBasicBlock *UncondSucc =
-        UncondBr ? (UncondBr->getOpcode() == X86::JMP_1
-                        ? UncondBr->getOperand(0).getMBB()
-                        : nullptr)
-                 : &*std::next(MachineFunction::iterator(&MBB));
-
-    // Count how many edges there are to any given successor.
-    SmallDenseMap<MachineBasicBlock *, int> SuccCounts;
-    if (UncondSucc)
-      ++SuccCounts[UncondSucc];
-    for (auto *CondBr : CondBrs)
-      ++SuccCounts[CondBr->getOperand(0).getMBB()];
-
-    // A lambda to insert cmov instructions into a block checking all of the
-    // condition codes in a sequence.
-    auto BuildCheckingBlockForSuccAndConds =
-        [&](MachineBasicBlock &MBB, MachineBasicBlock &Succ, int SuccCount,
-            MachineInstr *Br, MachineInstr *&UncondBr,
-            ArrayRef<X86::CondCode> Conds) {
-          // First, we split the edge to insert the checking block into a safe
-          // location.
-          auto &CheckingMBB =
-              (SuccCount == 1 && Succ.pred_size() == 1)
-                  ? Succ
-                  : splitEdge(MBB, Succ, SuccCount, Br, UncondBr, *TII);
-
-          bool LiveEFLAGS = Succ.isLiveIn(X86::EFLAGS);
-          if (!LiveEFLAGS)
-            CheckingMBB.addLiveIn(X86::EFLAGS);
-
-          // Now insert the cmovs to implement the checks.
-          auto InsertPt = CheckingMBB.begin();
-          assert((InsertPt == CheckingMBB.end() || !InsertPt->isPHI()) &&
-                 "Should never have a PHI in the initial checking block as it "
-                 "always has a single predecessor!");
-
-          // We will wire each cmov to each other, but need to start with the
-          // incoming pred state.
-          unsigned CurStateReg = PredStateReg;
-
-          for (X86::CondCode Cond : Conds) {
-            auto CMovOp = X86::getCMovFromCond(Cond, PredStateSizeInBytes);
-
-            unsigned UpdatedStateReg = MRI->createVirtualRegister(PredStateRC);
-            auto CMovI = BuildMI(CheckingMBB, InsertPt, Loc, TII->get(CMovOp),
-                                 UpdatedStateReg)
-                             .addReg(CurStateReg)
-                             .addReg(PoisonReg);
-            // If this is the last cmov and the EFLAGS weren't originally
-            // live-in, mark them as killed.
-            if (!LiveEFLAGS && Cond == Conds.back())
-              CMovI->findRegisterUseOperand(X86::EFLAGS)->setIsKill(true);
-
-            ++NumInstsInserted;
-            LLVM_DEBUG(dbgs() << "  Inserting cmov: "; CMovI->dump();
-                       dbgs() << "\n");
-
-            // The first one of the cmovs will be using the top level
-            // `PredStateReg` and need to get rewritten into SSA form.
-            if (CurStateReg == PredStateReg)
-              CMovs.push_back(&*CMovI);
-
-            // The next cmov should start from this one's def.
-            CurStateReg = UpdatedStateReg;
-          }
-
-          // And put the last one into the available values for PredStateSSA.
-          PredStateSSA.AddAvailableValue(&CheckingMBB, CurStateReg);
-        };
-
-    std::vector<X86::CondCode> UncondCodeSeq;
-    for (auto *CondBr : CondBrs) {
-      MachineBasicBlock &Succ = *CondBr->getOperand(0).getMBB();
-      int &SuccCount = SuccCounts[&Succ];
-
-      X86::CondCode Cond = X86::getCondFromBranchOpc(CondBr->getOpcode());
-      X86::CondCode InvCond = X86::GetOppositeBranchCondition(Cond);
-      UncondCodeSeq.push_back(Cond);
-
-      BuildCheckingBlockForSuccAndConds(MBB, Succ, SuccCount, CondBr, UncondBr,
-                                        {InvCond});
-
-      // Decrement the successor count now that we've split one of the edges.
-      // We need to keep the count of edges to the successor accurate in order
-      // to know above when to *replace* the successor in the CFG vs. just
-      // adding the new successor.
-      --SuccCount;
-    }
-
-    // Since we may have split edges and changed the number of successors,
-    // normalize the probabilities. This avoids doing it each time we split an
-    // edge.
-    MBB.normalizeSuccProbs();
-
-    // Finally, we need to insert cmovs into the "fallthrough" edge. Here, we
-    // need to intersect the other condition codes. We can do this by just
-    // doing a cmov for each one.
-    if (!UncondSucc)
-      // If we have no fallthrough to protect (perhaps it is an indirect jump?)
-      // just skip this and continue.
-      continue;
-
-    assert(SuccCounts[UncondSucc] == 1 &&
-           "We should never have more than one edge to the unconditional "
-           "successor at this point because every other edge must have been "
-           "split above!");
+  PS->SSA.Initialize(PS->InitialReg);
+  PS->SSA.AddAvailableValue(&Entry, PS->InitialReg);
 
-    // Sort and unique the codes to minimize them.
-    llvm::sort(UncondCodeSeq.begin(), UncondCodeSeq.end());
-    UncondCodeSeq.erase(std::unique(UncondCodeSeq.begin(), UncondCodeSeq.end()),
-                        UncondCodeSeq.end());
-
-    // Build a checking version of the successor.
-    BuildCheckingBlockForSuccAndConds(MBB, *UncondSucc, /*SuccCount*/ 1,
-                                      UncondBr, UncondBr, UncondCodeSeq);
-  }
+  // Trace through the CFG.
+  auto CMovs = tracePredStateThroughCFG(MF, Infos);
 
   // We may also enter basic blocks in this function via exception handling
   // control flow. Here, if we are hardening interprocedurally, we need to
@@ -615,23 +499,23 @@ bool X86SpeculativeLoadHardeningPass::ru
       assert(!MBB.isCleanupFuncletEntry() && "Only Itanium ABI EH supported!");
       if (!MBB.isEHPad())
         continue;
-      PredStateSSA.AddAvailableValue(
+      PS->SSA.AddAvailableValue(
           &MBB,
           extractPredStateFromSP(MBB, MBB.SkipPHIsAndLabels(MBB.begin()), Loc));
     }
   }
 
   // Now check all of the loads using the predicate state.
-  checkAllLoads(MF, PredStateSSA);
+  checkAllLoads(MF);
 
   // Now rewrite all the uses of the pred state using the SSA updater so that
   // we track updates through the CFG.
   for (MachineInstr *CMovI : CMovs)
     for (MachineOperand &Op : CMovI->operands()) {
-      if (!Op.isReg() || Op.getReg() != PredStateReg)
+      if (!Op.isReg() || Op.getReg() != PS->InitialReg)
         continue;
 
-      PredStateSSA.RewriteUse(Op);
+      PS->SSA.RewriteUse(Op);
     }
 
   // If we are hardening interprocedurally, find each returning block and
@@ -645,7 +529,7 @@ bool X86SpeculativeLoadHardeningPass::ru
       if (!MI.isReturn())
         continue;
 
-      checkReturnInstr(MI, PredStateSSA);
+      checkReturnInstr(MI);
     }
 
   LLVM_DEBUG(dbgs() << "Final speculative load hardened function:\n"; MF.dump();
@@ -776,6 +660,157 @@ X86SpeculativeLoadHardeningPass::collect
   return Infos;
 }
 
+/// Trace the predicate state through the CFG, instrumenting each conditional
+/// branch such that misspeculation through an edge will poison the predicate
+/// state.
+///
+/// Returns the list of inserted CMov instructions so that they can have their
+/// uses of the predicate state rewritten into proper SSA form once it is
+/// complete.
+SmallVector<MachineInstr *, 16>
+X86SpeculativeLoadHardeningPass::tracePredStateThroughCFG(
+    MachineFunction &MF, ArrayRef<BlockCondInfo> Infos) {
+  // Collect the inserted cmov instructions so we can rewrite their uses of the
+  // predicate state into SSA form.
+  SmallVector<MachineInstr *, 16> CMovs;
+
+  // Now walk all of the basic blocks looking for ones that end in conditional
+  // jumps where we need to update this register along each edge.
+  for (const BlockCondInfo &Info : Infos) {
+    MachineBasicBlock &MBB = *Info.MBB;
+    const SmallVectorImpl<MachineInstr *> &CondBrs = Info.CondBrs;
+    MachineInstr *UncondBr = Info.UncondBr;
+
+    LLVM_DEBUG(dbgs() << "Tracing predicate through block: " << MBB.getName()
+                      << "\n");
+    ++NumCondBranchesTraced;
+
+    // Compute the non-conditional successor as either the target of any
+    // unconditional branch or the layout successor.
+    MachineBasicBlock *UncondSucc =
+        UncondBr ? (UncondBr->getOpcode() == X86::JMP_1
+                        ? UncondBr->getOperand(0).getMBB()
+                        : nullptr)
+                 : &*std::next(MachineFunction::iterator(&MBB));
+
+    // Count how many edges there are to any given successor.
+    SmallDenseMap<MachineBasicBlock *, int> SuccCounts;
+    if (UncondSucc)
+      ++SuccCounts[UncondSucc];
+    for (auto *CondBr : CondBrs)
+      ++SuccCounts[CondBr->getOperand(0).getMBB()];
+
+    // A lambda to insert cmov instructions into a block checking all of the
+    // condition codes in a sequence.
+    auto BuildCheckingBlockForSuccAndConds =
+        [&](MachineBasicBlock &MBB, MachineBasicBlock &Succ, int SuccCount,
+            MachineInstr *Br, MachineInstr *&UncondBr,
+            ArrayRef<X86::CondCode> Conds) {
+          // First, we split the edge to insert the checking block into a safe
+          // location.
+          auto &CheckingMBB =
+              (SuccCount == 1 && Succ.pred_size() == 1)
+                  ? Succ
+                  : splitEdge(MBB, Succ, SuccCount, Br, UncondBr, *TII);
+
+          bool LiveEFLAGS = Succ.isLiveIn(X86::EFLAGS);
+          if (!LiveEFLAGS)
+            CheckingMBB.addLiveIn(X86::EFLAGS);
+
+          // Now insert the cmovs to implement the checks.
+          auto InsertPt = CheckingMBB.begin();
+          assert((InsertPt == CheckingMBB.end() || !InsertPt->isPHI()) &&
+                 "Should never have a PHI in the initial checking block as it "
+                 "always has a single predecessor!");
+
+          // We will wire each cmov to each other, but need to start with the
+          // incoming pred state.
+          unsigned CurStateReg = PS->InitialReg;
+
+          for (X86::CondCode Cond : Conds) {
+            int PredStateSizeInBytes = TRI->getRegSizeInBits(*PS->RC) / 8;
+            auto CMovOp = X86::getCMovFromCond(Cond, PredStateSizeInBytes);
+
+            unsigned UpdatedStateReg = MRI->createVirtualRegister(PS->RC);
+            // Note that we intentionally use an empty debug location so that
+            // this picks up the preceding location.
+            auto CMovI = BuildMI(CheckingMBB, InsertPt, DebugLoc(),
+                                 TII->get(CMovOp), UpdatedStateReg)
+                             .addReg(CurStateReg)
+                             .addReg(PS->PoisonReg);
+            // If this is the last cmov and the EFLAGS weren't originally
+            // live-in, mark them as killed.
+            if (!LiveEFLAGS && Cond == Conds.back())
+              CMovI->findRegisterUseOperand(X86::EFLAGS)->setIsKill(true);
+
+            ++NumInstsInserted;
+            LLVM_DEBUG(dbgs() << "  Inserting cmov: "; CMovI->dump();
+                       dbgs() << "\n");
+
+            // The first one of the cmovs will be using the top level
+            // `PS->InitialReg` and needs to get rewritten into SSA form.
+            if (CurStateReg == PS->InitialReg)
+              CMovs.push_back(&*CMovI);
+
+            // The next cmov should start from this one's def.
+            CurStateReg = UpdatedStateReg;
+          }
+
+          // And put the last one into the available values for SSA form of our
+          // predicate state.
+          PS->SSA.AddAvailableValue(&CheckingMBB, CurStateReg);
+        };
+
+    std::vector<X86::CondCode> UncondCodeSeq;
+    for (auto *CondBr : CondBrs) {
+      MachineBasicBlock &Succ = *CondBr->getOperand(0).getMBB();
+      int &SuccCount = SuccCounts[&Succ];
+
+      X86::CondCode Cond = X86::getCondFromBranchOpc(CondBr->getOpcode());
+      X86::CondCode InvCond = X86::GetOppositeBranchCondition(Cond);
+      UncondCodeSeq.push_back(Cond);
+
+      BuildCheckingBlockForSuccAndConds(MBB, Succ, SuccCount, CondBr, UncondBr,
+                                        {InvCond});
+
+      // Decrement the successor count now that we've split one of the edges.
+      // We need to keep the count of edges to the successor accurate in order
+      // to know above when to *replace* the successor in the CFG vs. just
+      // adding the new successor.
+      --SuccCount;
+    }
+
+    // Since we may have split edges and changed the number of successors,
+    // normalize the probabilities. This avoids doing it each time we split an
+    // edge.
+    MBB.normalizeSuccProbs();
+
+    // Finally, we need to insert cmovs into the "fallthrough" edge. Here, we
+    // need to intersect the other condition codes. We can do this by just
+    // doing a cmov for each one.
+    if (!UncondSucc)
+      // If we have no fallthrough to protect (perhaps it is an indirect jump?)
+      // just skip this and continue.
+      continue;
+
+    assert(SuccCounts[UncondSucc] == 1 &&
+           "We should never have more than one edge to the unconditional "
+           "successor at this point because every other edge must have been "
+           "split above!");
+
+    // Sort and unique the codes to minimize them.
+    llvm::sort(UncondCodeSeq.begin(), UncondCodeSeq.end());
+    UncondCodeSeq.erase(std::unique(UncondCodeSeq.begin(), UncondCodeSeq.end()),
+                        UncondCodeSeq.end());
+
+    // Build a checking version of the successor.
+    BuildCheckingBlockForSuccAndConds(MBB, *UncondSucc, /*SuccCount*/ 1,
+                                      UncondBr, UncondBr, UncondCodeSeq);
+  }
+
+  return CMovs;
+}
+
 /// Returns true if the instruction has no behavior (specified or otherwise)
 /// that is based on the value of any of its register operands
 ///
@@ -1190,8 +1225,7 @@ static bool isEFLAGSLive(MachineBasicBlo
   return MBB.isLiveIn(X86::EFLAGS);
 }
 
-void X86SpeculativeLoadHardeningPass::checkAllLoads(
-    MachineFunction &MF, MachineSSAUpdater &PredStateSSA) {
+void X86SpeculativeLoadHardeningPass::checkAllLoads(MachineFunction &MF) {
   // If the actual checking of loads is disabled, skip doing anything here.
   if (!HardenLoads)
     return;
@@ -1345,7 +1379,7 @@ void X86SpeculativeLoadHardeningPass::ch
             MI.getOperand(MemRefBeginIdx + X86::AddrBaseReg);
         MachineOperand &IndexMO =
             MI.getOperand(MemRefBeginIdx + X86::AddrIndexReg);
-        hardenLoadAddr(MI, BaseMO, IndexMO, PredStateSSA, AddrRegToHardenedReg);
+        hardenLoadAddr(MI, BaseMO, IndexMO, AddrRegToHardenedReg);
         continue;
       }
 
@@ -1381,7 +1415,7 @@ void X86SpeculativeLoadHardeningPass::ch
         AddrRegToHardenedReg[MI.getOperand(0).getReg()] =
             MI.getOperand(0).getReg();
 
-        hardenPostLoad(MI, PredStateSSA);
+        hardenPostLoad(MI);
         continue;
       }
 
@@ -1401,7 +1435,7 @@ void X86SpeculativeLoadHardeningPass::ch
       // First, we transfer the predicate state into the called function by
       // merging it into the stack pointer. This will kill the current def of
       // the state.
-      unsigned StateReg = PredStateSSA.GetValueAtEndOfBlock(&MBB);
+      unsigned StateReg = PS->SSA.GetValueAtEndOfBlock(&MBB);
       mergePredStateIntoSP(MBB, InsertPt, Loc, StateReg);
 
       // If this call is also a return (because it is a tail call) we're done.
@@ -1412,7 +1446,7 @@ void X86SpeculativeLoadHardeningPass::ch
       // state from SP after the return, and make this new state available.
       ++InsertPt;
       unsigned NewStateReg = extractPredStateFromSP(MBB, InsertPt, Loc);
-      PredStateSSA.AddAvailableValue(&MBB, NewStateReg);
+      PS->SSA.AddAvailableValue(&MBB, NewStateReg);
     }
 
     HardenPostLoad.clear();
@@ -1465,7 +1499,7 @@ void X86SpeculativeLoadHardeningPass::re
 void X86SpeculativeLoadHardeningPass::mergePredStateIntoSP(
     MachineBasicBlock &MBB, MachineBasicBlock::iterator InsertPt, DebugLoc Loc,
     unsigned PredStateReg) {
-  unsigned TmpReg = MRI->createVirtualRegister(PredStateRC);
+  unsigned TmpReg = MRI->createVirtualRegister(PS->RC);
   // FIXME: This hard codes a shift distance based on the number of bits needed
   // to stay canonical on 64-bit. We should compute this somehow and support
   // 32-bit as part of that.
@@ -1485,8 +1519,8 @@ void X86SpeculativeLoadHardeningPass::me
 unsigned X86SpeculativeLoadHardeningPass::extractPredStateFromSP(
     MachineBasicBlock &MBB, MachineBasicBlock::iterator InsertPt,
     DebugLoc Loc) {
-  unsigned PredStateReg = MRI->createVirtualRegister(PredStateRC);
-  unsigned TmpReg = MRI->createVirtualRegister(PredStateRC);
+  unsigned PredStateReg = MRI->createVirtualRegister(PS->RC);
+  unsigned TmpReg = MRI->createVirtualRegister(PS->RC);
 
   // We know that the stack pointer will have any preserved predicate state in
   // its high bit. We just want to smear this across the other bits. Turns out,
@@ -1496,7 +1530,7 @@ unsigned X86SpeculativeLoadHardeningPass
   auto ShiftI =
       BuildMI(MBB, InsertPt, Loc, TII->get(X86::SAR64ri), PredStateReg)
           .addReg(TmpReg, RegState::Kill)
-          .addImm(TRI->getRegSizeInBits(*PredStateRC) - 1);
+          .addImm(TRI->getRegSizeInBits(*PS->RC) - 1);
   ShiftI->addRegisterDead(X86::EFLAGS, TRI);
   ++NumInstsInserted;
 
@@ -1505,7 +1539,6 @@ unsigned X86SpeculativeLoadHardeningPass
 
 void X86SpeculativeLoadHardeningPass::hardenLoadAddr(
     MachineInstr &MI, MachineOperand &BaseMO, MachineOperand &IndexMO,
-    MachineSSAUpdater &PredStateSSA,
     SmallDenseMap<unsigned, unsigned, 32> &AddrRegToHardenedReg) {
   MachineBasicBlock &MBB = *MI.getParent();
   DebugLoc Loc = MI.getDebugLoc();
@@ -1570,7 +1603,7 @@ void X86SpeculativeLoadHardeningPass::ha
     return;
 
   // Compute the current predicate state.
-  unsigned StateReg = PredStateSSA.GetValueAtEndOfBlock(&MBB);
+  unsigned StateReg = PS->SSA.GetValueAtEndOfBlock(&MBB);
 
   auto InsertPt = MI.getIterator();
 
@@ -1818,8 +1851,7 @@ bool X86SpeculativeLoadHardeningPass::ca
 // this because having the poison value be all ones allows us to use the same
 // value below. And the goal is just for the loaded bits to not be exposed to
 // execution and coercing them to one is sufficient.
-void X86SpeculativeLoadHardeningPass::hardenPostLoad(
-    MachineInstr &MI, MachineSSAUpdater &PredStateSSA) {
+void X86SpeculativeLoadHardeningPass::hardenPostLoad(MachineInstr &MI) {
   MachineBasicBlock &MBB = *MI.getParent();
   DebugLoc Loc = MI.getDebugLoc();
 
@@ -1838,7 +1870,7 @@ void X86SpeculativeLoadHardeningPass::ha
   unsigned SubRegImms[] = {X86::sub_8bit, X86::sub_16bit, X86::sub_32bit};
 
   auto GetStateRegInRC = [&](const TargetRegisterClass &RC) {
-    unsigned StateReg = PredStateSSA.GetValueAtEndOfBlock(&MBB);
+    unsigned StateReg = PS->SSA.GetValueAtEndOfBlock(&MBB);
 
     int Bytes = TRI->getRegSizeInBits(RC) / 8;
     // FIXME: Need to teach this about 32-bit mode.
@@ -1937,8 +1969,7 @@ void X86SpeculativeLoadHardeningPass::ha
   ++NumPostLoadRegsHardened;
 }
 
-void X86SpeculativeLoadHardeningPass::checkReturnInstr(
-    MachineInstr &MI, MachineSSAUpdater &PredStateSSA) {
+void X86SpeculativeLoadHardeningPass::checkReturnInstr(MachineInstr &MI) {
   MachineBasicBlock &MBB = *MI.getParent();
   DebugLoc Loc = MI.getDebugLoc();
   auto InsertPt = MI.getIterator();
@@ -1962,8 +1993,7 @@ void X86SpeculativeLoadHardeningPass::ch
   // Take our predicate state, shift it to the high 17 bits (so that we keep
   // pointers canonical) and merge it into RSP. This will allow the caller to
   // extract it when we return (speculatively).
-  mergePredStateIntoSP(MBB, InsertPt, Loc,
-                       PredStateSSA.GetValueAtEndOfBlock(&MBB));
+  mergePredStateIntoSP(MBB, InsertPt, Loc, PS->SSA.GetValueAtEndOfBlock(&MBB));
 }
 
 INITIALIZE_PASS_BEGIN(X86SpeculativeLoadHardeningPass, DEBUG_TYPE,




More information about the llvm-commits mailing list