[llvm] ec1445c - [X86] Fix for ballooning compile times due to Load Value Injection (LVI) mitigations

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 30 17:43:59 PDT 2020


Author: Scott Constable
Date: 2020-07-30T17:22:33-07:00
New Revision: ec1445c5afda7f145a414f11c9103c87a4c1823f

URL: https://github.com/llvm/llvm-project/commit/ec1445c5afda7f145a414f11c9103c87a4c1823f
DIFF: https://github.com/llvm/llvm-project/commit/ec1445c5afda7f145a414f11c9103c87a4c1823f.diff

LOG: [X86] Fix for ballooning compile times due to Load Value Injection (LVI) mitigations

Fix for the issue raised in https://github.com/rust-lang/rust/issues/74632.

The current heuristic for inserting LFENCEs uses a quadratic-time algorithm. This can cause substantial compile-time slowdowns when building Rust projects, where functions longer than 5000 lines of code are apparently common.

The updated heuristic in this patch implements a linear-time algorithm. On a set of benchmarks, the slowdown factor of the generated code was comparable for the two heuristics (a geometric mean of 2.55x for the quadratic-time heuristic vs. 2.58x for the linear-time heuristic). Both heuristics offer the same security property, namely mitigating LVI.

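This patch also includes some formatting and naming cleanups, for example renaming identifiers to follow LLVM naming conventions (`WriteGadgetGraph` -> `writeGadgetGraph`, CamelCase parameter names) and adding `Edge`/`Node` type aliases.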

Differential Revision: https://reviews.llvm.org/D84471

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
index 50f8b3477acc..18fcc48bc9cd 100644
--- a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
+++ b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
@@ -42,6 +42,7 @@
 #include "X86TargetMachine.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
@@ -104,9 +105,9 @@ static cl::opt<bool> EmitDotVerify(
     cl::init(false), cl::Hidden);
 
 static llvm::sys::DynamicLibrary OptimizeDL;
-typedef int (*OptimizeCutT)(unsigned int *nodes, unsigned int nodes_size,
-                            unsigned int *edges, int *edge_values,
-                            int *cut_edges /* out */, unsigned int edges_size);
+typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize,
+                            unsigned int *Edges, int *EdgeValues,
+                            int *CutEdges /* out */, unsigned int EdgesSize);
 static OptimizeCutT OptimizeCut = nullptr;
 
 namespace {
@@ -148,9 +149,10 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
 
 private:
   using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
+  using Edge = MachineGadgetGraph::Edge;
+  using Node = MachineGadgetGraph::Node;
   using EdgeSet = MachineGadgetGraph::EdgeSet;
   using NodeSet = MachineGadgetGraph::NodeSet;
-  using Gadget = std::pair<MachineInstr *, MachineInstr *>;
 
   const X86Subtarget *STI;
   const TargetInstrInfo *TII;
@@ -162,8 +164,8 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
                  const MachineDominanceFrontier &MDF) const;
   int hardenLoadsWithPlugin(MachineFunction &MF,
                             std::unique_ptr<MachineGadgetGraph> Graph) const;
-  int hardenLoadsWithGreedyHeuristic(
-      MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const;
+  int hardenLoadsWithHeuristic(MachineFunction &MF,
+                               std::unique_ptr<MachineGadgetGraph> Graph) const;
   int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G,
                                  EdgeSet &ElimEdges /* in, out */,
                                  NodeSet &ElimNodes /* in, out */) const;
@@ -198,7 +200,7 @@ struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
   using ChildIteratorType = typename Traits::ChildIteratorType;
   using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;
 
-  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
 
   std::string getNodeLabel(NodeRef Node, GraphType *) {
     if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel)
@@ -243,7 +245,7 @@ void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
   AU.setPreservesCFG();
 }
 
-static void WriteGadgetGraph(raw_ostream &OS, MachineFunction &MF,
+static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF,
                              MachineGadgetGraph *G) {
   WriteGraph(OS, G, /*ShortNames*/ false,
              "Speculative gadgets for \"" + MF.getName() + "\" function");
@@ -279,7 +281,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
     return false; // didn't find any gadgets
 
   if (EmitDotVerify) {
-    WriteGadgetGraph(outs(), MF, Graph.get());
+    writeGadgetGraph(outs(), MF, Graph.get());
     return false;
   }
 
@@ -292,7 +294,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
     raw_fd_ostream FileOut(FileName, FileError);
     if (FileError)
       errs() << FileError.message();
-    WriteGadgetGraph(FileOut, MF, Graph.get());
+    writeGadgetGraph(FileOut, MF, Graph.get());
     FileOut.close();
     LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n");
     if (EmitDotOnly)
@@ -313,7 +315,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
     }
     FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph));
   } else { // Use the default greedy heuristic
-    FencesInserted = hardenLoadsWithGreedyHeuristic(MF, std::move(Graph));
+    FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph));
   }
 
   if (FencesInserted > 0)
@@ -540,17 +542,17 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
 
 // Returns the number of remaining gadget edges that could not be eliminated
 int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
-    MachineGadgetGraph &G, MachineGadgetGraph::EdgeSet &ElimEdges /* in, out */,
-    MachineGadgetGraph::NodeSet &ElimNodes /* in, out */) const {
+    MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */,
+    NodeSet &ElimNodes /* in, out */) const {
   if (G.NumFences > 0) {
     // Eliminate fences and CFG edges that ingress and egress the fence, as
     // they are trivially mitigated.
-    for (const auto &E : G.edges()) {
-      const MachineGadgetGraph::Node *Dest = E.getDest();
+    for (const Edge &E : G.edges()) {
+      const Node *Dest = E.getDest();
       if (isFence(Dest->getValue())) {
         ElimNodes.insert(*Dest);
         ElimEdges.insert(E);
-        for (const auto &DE : Dest->edges())
+        for (const Edge &DE : Dest->edges())
           ElimEdges.insert(DE);
       }
     }
@@ -558,29 +560,28 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
 
   // Find and eliminate gadget edges that have been mitigated.
   int MitigatedGadgets = 0, RemainingGadgets = 0;
-  MachineGadgetGraph::NodeSet ReachableNodes{G};
-  for (const auto &RootN : G.nodes()) {
+  NodeSet ReachableNodes{G};
+  for (const Node &RootN : G.nodes()) {
     if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge))
       continue; // skip this node if it isn't a gadget source
 
     // Find all of the nodes that are CFG-reachable from RootN using DFS
     ReachableNodes.clear();
-    std::function<void(const MachineGadgetGraph::Node *, bool)>
-        FindReachableNodes =
-            [&](const MachineGadgetGraph::Node *N, bool FirstNode) {
-              if (!FirstNode)
-                ReachableNodes.insert(*N);
-              for (const auto &E : N->edges()) {
-                const MachineGadgetGraph::Node *Dest = E.getDest();
-                if (MachineGadgetGraph::isCFGEdge(E) &&
-                    !ElimEdges.contains(E) && !ReachableNodes.contains(*Dest))
-                  FindReachableNodes(Dest, false);
-              }
-            };
+    std::function<void(const Node *, bool)> FindReachableNodes =
+        [&](const Node *N, bool FirstNode) {
+          if (!FirstNode)
+            ReachableNodes.insert(*N);
+          for (const Edge &E : N->edges()) {
+            const Node *Dest = E.getDest();
+            if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) &&
+                !ReachableNodes.contains(*Dest))
+              FindReachableNodes(Dest, false);
+          }
+        };
     FindReachableNodes(&RootN, true);
 
     // Any gadget whose sink is unreachable has been mitigated
-    for (const auto &E : RootN.edges()) {
+    for (const Edge &E : RootN.edges()) {
       if (MachineGadgetGraph::isGadgetEdge(E)) {
         if (ReachableNodes.contains(*E.getDest())) {
           // This gadget's sink is reachable
@@ -598,8 +599,8 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
 std::unique_ptr<MachineGadgetGraph>
 X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges(
     std::unique_ptr<MachineGadgetGraph> Graph) const {
-  MachineGadgetGraph::NodeSet ElimNodes{*Graph};
-  MachineGadgetGraph::EdgeSet ElimEdges{*Graph};
+  NodeSet ElimNodes{*Graph};
+  EdgeSet ElimEdges{*Graph};
   int RemainingGadgets =
       elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes);
   if (ElimEdges.empty() && ElimNodes.empty()) {
@@ -630,11 +631,11 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
     auto Edges = std::make_unique<unsigned int[]>(Graph->edges_size());
     auto EdgeCuts = std::make_unique<int[]>(Graph->edges_size());
     auto EdgeValues = std::make_unique<int[]>(Graph->edges_size());
-    for (const auto &N : Graph->nodes()) {
+    for (const Node &N : Graph->nodes()) {
       Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin());
     }
     Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node
-    for (const auto &E : Graph->edges()) {
+    for (const Edge &E : Graph->edges()) {
       Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest());
       EdgeValues[Graph->getEdgeIndex(E)] = E.getValue();
     }
@@ -651,74 +652,67 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
     LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
     LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n");
 
-    Graph = GraphBuilder::trim(*Graph, MachineGadgetGraph::NodeSet{*Graph},
-                               CutEdges);
+    Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges);
   } while (true);
 
   return FencesInserted;
 }
 
-int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic(
+int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic(
     MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
-  LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
-  Graph = trimMitigatedEdges(std::move(Graph));
-  LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
+  // If `MF` does not have any fences, then no gadgets would have been
+  // mitigated at this point.
+  if (Graph->NumFences > 0) {
+    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
+    Graph = trimMitigatedEdges(std::move(Graph));
+    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
+  }
+
   if (Graph->NumGadgets == 0)
     return 0;
 
   LLVM_DEBUG(dbgs() << "Cutting edges...\n");
-  MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph};
-  MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph};
-  auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) {
-    return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
-           MachineGadgetGraph::isCFGEdge(E);
-  };
-  auto IsGadgetEdge = [&ElimEdges,
-                       &CutEdges](const MachineGadgetGraph::Edge &E) {
-    return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
-           MachineGadgetGraph::isGadgetEdge(E);
-  };
-
-  // FIXME: this is O(E^2), we could probably do better.
-  do {
-    // Find the cheapest CFG edge that will eliminate a gadget (by being
-    // egress from a SOURCE node or ingress to a SINK node), and cut it.
-    const MachineGadgetGraph::Edge *CheapestSoFar = nullptr;
-
-    // First, collect all gadget source and sink nodes.
-    MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph};
-    for (const auto &N : Graph->nodes()) {
-      if (ElimNodes.contains(N))
+  EdgeSet CutEdges{*Graph};
+
+  // Begin by collecting all ingress CFG edges for each node
+  DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap;
+  for (const Edge &E : Graph->edges())
+    if (MachineGadgetGraph::isCFGEdge(E))
+      IngressEdgeMap[E.getDest()].push_back(&E);
+
+  // For each gadget edge, make cuts that guarantee the gadget will be
+  // mitigated. A computationally efficient way to achieve this is to either:
+  // (a) cut all egress CFG edges from the gadget source, or
+  // (b) cut all ingress CFG edges to the gadget sink.
+  //
+  // Moreover, the algorithm tries not to make a cut into a loop by preferring
+  // to make a (b)-type cut if the gadget source resides at a greater loop depth
+  // than the gadget sink, or an (a)-type cut otherwise.
+  for (const Node &N : Graph->nodes()) {
+    for (const Edge &E : N.edges()) {
+      if (!MachineGadgetGraph::isGadgetEdge(E))
         continue;
-      for (const auto &E : N.edges()) {
-        if (IsGadgetEdge(E)) {
-          GadgetSources.insert(N);
-          GadgetSinks.insert(*E.getDest());
-        }
-      }
-    }
 
-    // Next, look for the cheapest CFG edge which, when cut, is guaranteed to
-    // mitigate at least one gadget by either:
-    // (a) being egress from a gadget source, or
-    // (b) being ingress to a gadget sink.
-    for (const auto &N : Graph->nodes()) {
-      if (ElimNodes.contains(N))
-        continue;
-      for (const auto &E : N.edges()) {
-        if (IsCFGEdge(E)) {
-          if (GadgetSources.contains(N) || GadgetSinks.contains(*E.getDest())) {
-            if (!CheapestSoFar || E.getValue() < CheapestSoFar->getValue())
-              CheapestSoFar = &E;
-          }
-        }
-      }
+      SmallVector<const Edge *, 2> EgressEdges;
+      SmallVector<const Edge *, 2> &IngressEdges = IngressEdgeMap[E.getDest()];
+      for (const Edge &EgressEdge : N.edges())
+        if (MachineGadgetGraph::isCFGEdge(EgressEdge))
+          EgressEdges.push_back(&EgressEdge);
+
+      int EgressCutCost = 0, IngressCutCost = 0;
+      for (const Edge *EgressEdge : EgressEdges)
+        if (!CutEdges.contains(*EgressEdge))
+          EgressCutCost += EgressEdge->getValue();
+      for (const Edge *IngressEdge : IngressEdges)
+        if (!CutEdges.contains(*IngressEdge))
+          IngressCutCost += IngressEdge->getValue();
+
+      auto &EdgesToCut =
+          IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges;
+      for (const Edge *E : EdgesToCut)
+        CutEdges.insert(*E);
     }
-
-    assert(CheapestSoFar && "Failed to cut an edge");
-    CutEdges.insert(*CheapestSoFar);
-    ElimEdges.insert(*CheapestSoFar);
-  } while (elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes));
+  }
   LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
   LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");
 
@@ -734,8 +728,8 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
     MachineFunction &MF, MachineGadgetGraph &G,
     EdgeSet &CutEdges /* in, out */) const {
   int FencesInserted = 0;
-  for (const auto &N : G.nodes()) {
-    for (const auto &E : N.edges()) {
+  for (const Node &N : G.nodes()) {
+    for (const Edge &E : N.edges()) {
       if (CutEdges.contains(E)) {
         MachineInstr *MI = N.getValue(), *Prev;
         MachineBasicBlock *MBB;                  // Insert an LFENCE in this MBB
@@ -751,7 +745,7 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
           Prev = MI->getPrevNode();
           // Remove all egress CFG edges from this branch because the inserted
           // LFENCE prevents gadgets from crossing the branch.
-          for (const auto &E : N.edges()) {
+          for (const Edge &E : N.edges()) {
             if (MachineGadgetGraph::isCFGEdge(E))
               CutEdges.insert(E);
           }


        


More information about the llvm-commits mailing list