[llvm] r362857 - Factor out SelectionDAG's switch analysis and lowering into a separate component.

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 7 17:05:18 PDT 2019


Author: aemerson
Date: Fri Jun  7 17:05:17 2019
New Revision: 362857

URL: http://llvm.org/viewvc/llvm-project?rev=362857&view=rev
Log:
Factor out SelectionDAG's switch analysis and lowering into a separate component.

In order for GlobalISel to reuse the substantial amount of analysis and
optimization code in SelectionDAG's (SDAG's) switch lowering, we first have to
extract it and create an interface that both frameworks can use.

No test changes, as this is NFC (no functional change).

Differential Revision: https://reviews.llvm.org/D62745

Added:
    llvm/trunk/include/llvm/CodeGen/SwitchLoweringUtils.h
    llvm/trunk/lib/CodeGen/SwitchLoweringUtils.cpp
Modified:
    llvm/trunk/lib/CodeGen/CMakeLists.txt
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp

Added: llvm/trunk/include/llvm/CodeGen/SwitchLoweringUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SwitchLoweringUtils.h?rev=362857&view=auto
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SwitchLoweringUtils.h (added)
+++ llvm/trunk/include/llvm/CodeGen/SwitchLoweringUtils.h Fri Jun  7 17:05:17 2019
@@ -0,0 +1,275 @@
+//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/BranchProbability.h"
+
+namespace llvm {
+
+class FunctionLoweringInfo;
+class MachineBasicBlock;
+
+namespace SwitchCG {
+
+enum CaseClusterKind { // Selects which CaseCluster union member is valid.
+  /// Adjacent case labels sharing one destination (CaseCluster::MBB), or a
+  /// single case.
+  CC_Range,
+  /// Cases lowered through a jump table (uses CaseCluster::JTCasesIndex).
+  CC_JumpTable,
+  /// Cases lowered through bit tests (uses CaseCluster::BTCasesIndex).
+  CC_BitTests
+};
+
+/// A cluster of case labels [Low, High] (inclusive); see the factories below.
+struct CaseCluster {
+  CaseClusterKind Kind; // Selects which union member below is valid.
+  const ConstantInt *Low, *High; // Inclusive bounds of the case values.
+  union {
+    MachineBasicBlock *MBB; // CC_Range: destination of every case.
+    unsigned JTCasesIndex;  // CC_JumpTable: index into JTCases.
+    unsigned BTCasesIndex;  // CC_BitTests: index into BitTestCases.
+  };
+  BranchProbability Prob; // Combined probability of the cluster's cases.
+
+  static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
+                           MachineBasicBlock *MBB, BranchProbability Prob) {
+    CaseCluster C;
+    C.Kind = CC_Range;
+    C.Low = Low;
+    C.High = High;
+    C.MBB = MBB;
+    C.Prob = Prob;
+    return C;
+  }
+
+  static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
+                               unsigned JTCasesIndex, BranchProbability Prob) {
+    CaseCluster C;
+    C.Kind = CC_JumpTable;
+    C.Low = Low;
+    C.High = High;
+    C.JTCasesIndex = JTCasesIndex;
+    C.Prob = Prob;
+    return C;
+  }
+
+  static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
+                              unsigned BTCasesIndex, BranchProbability Prob) {
+    CaseCluster C;
+    C.Kind = CC_BitTests;
+    C.Low = Low;
+    C.High = High;
+    C.BTCasesIndex = BTCasesIndex;
+    C.Prob = Prob;
+    return C;
+  }
+};
+
+using CaseClusterVector = std::vector<CaseCluster>;
+using CaseClusterIt = CaseClusterVector::iterator;
+
+/// Sort Clusters by signed Low, then merge adjacent same-destination cases.
+void sortAndRangeify(CaseClusterVector &Clusters);
+
+struct CaseBits { // Bit-test data accumulated for one destination block.
+  uint64_t Mask = 0; // Bit i set: case value (base + i) goes to BB.
+  MachineBasicBlock *BB = nullptr; // Destination of the masked values.
+  unsigned Bits = 0; // Number of case values recorded in Mask.
+  BranchProbability ExtraProb; // Summed probability of those cases.
+
+  CaseBits() = default;
+  CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
+           BranchProbability Prob)
+      : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
+};
+
+using CaseBitsVector = std::vector<CaseBits>;
+
+/// This structure is used to communicate between SelectionDAGBuilder and
+/// SDISel for the code generation of additional basic blocks needed by
+/// multi-case switch statements (and merged branch conditions).
+struct CaseBlock {
+  /// The condition code to use for the case block's setcc node.
+  /// Besides the integer condition codes, this can also be SETTRUE, in which
+  /// case no comparison gets emitted.
+  ISD::CondCode CC;
+
+  /// The LHS/MHS/RHS of the comparison to emit.
+  /// Emit by default LHS op RHS. MHS is used for range comparisons:
+  /// If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
+  const Value *CmpLHS, *CmpMHS, *CmpRHS;
+
+  /// The blocks to branch to if the setcc is true/false respectively.
+  MachineBasicBlock *TrueBB, *FalseBB;
+
+  /// The block into which to emit the code for the setcc and branches.
+  MachineBasicBlock *ThisBB;
+
+  /// The debug location of the instruction this CaseBlock was
+  /// produced from.
+  SDLoc DL;
+
+  /// Probabilities of the edges to TrueBB and FalseBB.
+  BranchProbability TrueProb, FalseProb;
+  /// Note: the middle operand (cmpmiddle) is passed *after* cmprhs.
+  CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
+            const Value *cmpmiddle, MachineBasicBlock *truebb,
+            MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
+            BranchProbability trueprob = BranchProbability::getUnknown(),
+            BranchProbability falseprob = BranchProbability::getUnknown())
+      : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
+        TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
+        TrueProb(trueprob), FalseProb(falseprob) {}
+};
+
+struct JumpTable { // Indirect-jump part of a JumpTableBlock.
+  /// The virtual register containing the index of the jump table entry
+  /// to jump to; -1U until the jump table header has been lowered.
+  unsigned Reg;
+  /// The JumpTableIndex for this jump table in the function.
+  unsigned JTI;
+  /// The MBB into which to emit the code for the indirect jump.
+  MachineBasicBlock *MBB;
+  /// The MBB of the default bb, which is a successor of the range
+  /// check MBB.  This is used when updating PHI nodes in successors.
+  MachineBasicBlock *Default;
+
+  JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
+      : Reg(R), JTI(J), MBB(M), Default(D) {}
+};
+struct JumpTableHeader { // Index/range-check part of a JumpTableBlock.
+  APInt First; // Lowest case value covered by the table.
+  APInt Last;  // Highest case value covered by the table.
+  const Value *SValue; // The switch condition value.
+  MachineBasicBlock *HeaderBB; // Null when created by buildJumpTable.
+  bool Emitted;
+  bool OmitRangeCheck; // Always false on construction.
+
+  JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
+                  bool E = false)
+      : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
+        Emitted(E), OmitRangeCheck(false) {}
+};
+using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
+
+struct BitTestCase { // One mask test within a BitTestBlock.
+  uint64_t Mask;               // Bits for the case values sent to TargetBB.
+  MachineBasicBlock *ThisBB;   // Block that performs this test.
+  MachineBasicBlock *TargetBB; // Destination when the test matches.
+  BranchProbability ExtraProb; // Summed probability of the masked cases.
+
+  BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
+              BranchProbability Prob)
+      : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
+};
+
+using BitTestInfo = SmallVector<BitTestCase, 3>; // One entry per destination.
+
+struct BitTestBlock { // A bit-test cluster; the element type of BitTestCases.
+  APInt First; // Base the masks are relative to (0 or the lowest case).
+  APInt Range; // Highest case value minus First.
+  const Value *SValue; // The switch condition value.
+  unsigned Reg; // -1U when created by buildBitTests.
+  MVT RegVT;    // MVT::Other when created by buildBitTests.
+  bool Emitted;
+  bool ContiguousRange; // If true, no value in the range reaches Default.
+  MachineBasicBlock *Parent;  // Null when created by buildBitTests.
+  MachineBasicBlock *Default; // Null when created by buildBitTests.
+  BitTestInfo Cases;          // One test per destination block.
+  BranchProbability Prob;     // Summed probability of all the cases.
+  BranchProbability DefaultProb; // Not set by the constructor below.
+
+  BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
+               bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
+               BitTestInfo C, BranchProbability Pr)
+      : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
+        RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
+        Cases(std::move(C)), Prob(Pr) {}
+};
+
+/// Values spanned by Clusters[First..Last], gaps included (saturating).
+uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
+                           unsigned Last);
+
+/// Number of cases in Clusters[First..Last]; TotalCases holds prefix sums.
+uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+                              unsigned First, unsigned Last);
+
+struct SwitchWorkListItem { // A pending range of clusters to lower.
+  MachineBasicBlock *MBB; // Block in which these clusters are lowered.
+  CaseClusterIt FirstCluster;
+  CaseClusterIt LastCluster;
+  const ConstantInt *GE; // Presumably a known bound Cond >= GE; confirm.
+  const ConstantInt *LT; // Presumably a known bound Cond < LT; confirm.
+  BranchProbability DefaultProb;
+};
+using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
+
+class SwitchLowering { // Switch-clustering logic factored out of SDAG.
+public:
+  explicit SwitchLowering(FunctionLoweringInfo &FI) : FuncInfo(FI) {}
+
+  /// Record the target hooks used by the clustering code; call before use.
+  void init(const TargetLowering &tli, const TargetMachine &tm,
+            const DataLayout &dl) {
+    TLI = &tli;
+    TM = &tm;
+    DL = &dl;
+  }
+
+  /// CaseBlocks queued by the branch/switch lowering for later emission.
+  std::vector<CaseBlock> SwitchCases;
+  /// Jump tables built so far; indexed by CaseCluster::JTCasesIndex.
+  std::vector<JumpTableBlock> JTCases;
+  /// Bit-test blocks built so far; indexed by CaseCluster::BTCasesIndex.
+  std::vector<BitTestBlock> BitTestCases;
+
+  /// Replace suitable runs of the (sorted, non-empty, all CC_Range) Clusters
+  /// with CC_JumpTable clusters, in place.
+  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
+                      MachineBasicBlock *DefaultMBB);
+
+  /// Build a jump table for Clusters[First..Last] and append it to JTCases.
+  /// Returns false if the range should be lowered as bit tests instead.
+  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
+                      unsigned Last, const SwitchInst *SI,
+                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
+
+  /// Replace suitable runs of CC_Range clusters with CC_BitTests, in place.
+  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
+
+  /// Build a bit-test cluster for Clusters[First..Last], or return false.
+  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
+                     const SwitchInst *SI, CaseCluster &BTCluster);
+
+  /// Add Dst as a successor of Src with probability Prob (client-defined).
+  virtual void addSuccessorWithProb(
+      MachineBasicBlock *Src, MachineBasicBlock *Dst,
+      BranchProbability Prob = BranchProbability::getUnknown()) = 0;
+
+  virtual ~SwitchLowering() = default;
+
+private:
+  const TargetLowering *TLI = nullptr; // Null until init().
+  const TargetMachine *TM = nullptr;   // Null until init().
+  const DataLayout *DL = nullptr;      // Null until init().
+  FunctionLoweringInfo &FuncInfo;
+};
+
+} // namespace SwitchCG
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+

Modified: llvm/trunk/lib/CodeGen/CMakeLists.txt
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/CMakeLists.txt?rev=362857&r1=362856&r2=362857&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/CMakeLists.txt (original)
+++ llvm/trunk/lib/CodeGen/CMakeLists.txt Fri Jun  7 17:05:17 2019
@@ -145,6 +145,7 @@ add_llvm_library(LLVMCodeGen
   StackProtector.cpp
   StackSlotColoring.cpp
   SwiftErrorValueTracking.cpp
+  SwitchLoweringUtils.cpp
   TailDuplication.cpp
   TailDuplicator.cpp
   TargetFrameLoweringImpl.cpp

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp?rev=362857&r1=362856&r2=362857&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Fri Jun  7 17:05:17 2019
@@ -124,6 +124,7 @@
 
 using namespace llvm;
 using namespace PatternMatch;
+using namespace SwitchCG;
 
 #define DEBUG_TYPE "isel"
 
@@ -1013,6 +1014,7 @@ void SelectionDAGBuilder::init(GCFunctio
   DL = &DAG.getDataLayout();
   Context = DAG.getContext();
   LPadToCallSiteMap.clear();
+  SL->init(DAG.getTargetLoweringInfo(), TM, DAG.getDataLayout());
 }
 
 void SelectionDAGBuilder::clear() {
@@ -2053,7 +2055,7 @@ SelectionDAGBuilder::EmitBranchForMerged
 
       CaseBlock CB(Condition, BOp->getOperand(0), BOp->getOperand(1), nullptr,
                    TBB, FBB, CurBB, getCurSDLoc(), TProb, FProb);
-      SwitchCases.push_back(CB);
+      SL->SwitchCases.push_back(CB);
       return;
     }
   }
@@ -2062,7 +2064,7 @@ SelectionDAGBuilder::EmitBranchForMerged
   ISD::CondCode Opc = InvertCond ? ISD::SETNE : ISD::SETEQ;
   CaseBlock CB(Opc, Cond, ConstantInt::getTrue(*DAG.getContext()),
                nullptr, TBB, FBB, CurBB, getCurSDLoc(), TProb, FProb);
-  SwitchCases.push_back(CB);
+  SL->SwitchCases.push_back(CB);
 }
 
 void SelectionDAGBuilder::FindMergedConditions(const Value *Cond,
@@ -2271,27 +2273,27 @@ void SelectionDAGBuilder::visitBr(const
       // If the compares in later blocks need to use values not currently
       // exported from this block, export them now.  This block should always
       // be the first entry.
-      assert(SwitchCases[0].ThisBB == BrMBB && "Unexpected lowering!");
+      assert(SL->SwitchCases[0].ThisBB == BrMBB && "Unexpected lowering!");
 
       // Allow some cases to be rejected.
-      if (ShouldEmitAsBranches(SwitchCases)) {
-        for (unsigned i = 1, e = SwitchCases.size(); i != e; ++i) {
-          ExportFromCurrentBlock(SwitchCases[i].CmpLHS);
-          ExportFromCurrentBlock(SwitchCases[i].CmpRHS);
+      if (ShouldEmitAsBranches(SL->SwitchCases)) {
+        for (unsigned i = 1, e = SL->SwitchCases.size(); i != e; ++i) {
+          ExportFromCurrentBlock(SL->SwitchCases[i].CmpLHS);
+          ExportFromCurrentBlock(SL->SwitchCases[i].CmpRHS);
         }
 
         // Emit the branch for this block.
-        visitSwitchCase(SwitchCases[0], BrMBB);
-        SwitchCases.erase(SwitchCases.begin());
+        visitSwitchCase(SL->SwitchCases[0], BrMBB);
+        SL->SwitchCases.erase(SL->SwitchCases.begin());
         return;
       }
 
       // Okay, we decided not to do this, remove any inserted MBB's and clear
       // SwitchCases.
-      for (unsigned i = 1, e = SwitchCases.size(); i != e; ++i)
-        FuncInfo.MF->erase(SwitchCases[i].ThisBB);
+      for (unsigned i = 1, e = SL->SwitchCases.size(); i != e; ++i)
+        FuncInfo.MF->erase(SL->SwitchCases[i].ThisBB);
 
-      SwitchCases.clear();
+      SL->SwitchCases.clear();
     }
   }
 
@@ -2399,7 +2401,7 @@ void SelectionDAGBuilder::visitSwitchCas
 }
 
 /// visitJumpTable - Emit JumpTable node in the current MBB
-void SelectionDAGBuilder::visitJumpTable(JumpTable &JT) {
+void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
   // Emit the code for the jump table
   assert(JT.Reg != -1U && "Should lower JT Header first!");
   EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
@@ -2414,7 +2416,7 @@ void SelectionDAGBuilder::visitJumpTable
 
 /// visitJumpTableHeader - This function emits necessary code to produce index
 /// in the JumpTable from switch case.
-void SelectionDAGBuilder::visitJumpTableHeader(JumpTable &JT,
+void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
                                                JumpTableHeader &JTH,
                                                MachineBasicBlock *SwitchBB) {
   SDLoc dl = getCurSDLoc();
@@ -2896,49 +2898,17 @@ void SelectionDAGBuilder::visitLandingPa
   setValue(&LP, Res);
 }
 
-void SelectionDAGBuilder::sortAndRangeify(CaseClusterVector &Clusters) {
-#ifndef NDEBUG
-  for (const CaseCluster &CC : Clusters)
-    assert(CC.Low == CC.High && "Input clusters must be single-case");
-#endif
-
-  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
-    return a.Low->getValue().slt(b.Low->getValue());
-  });
-
-  // Merge adjacent clusters with the same destination.
-  const unsigned N = Clusters.size();
-  unsigned DstIndex = 0;
-  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
-    CaseCluster &CC = Clusters[SrcIndex];
-    const ConstantInt *CaseVal = CC.Low;
-    MachineBasicBlock *Succ = CC.MBB;
-
-    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
-        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
-      // If this case has the same successor and is a neighbour, merge it into
-      // the previous cluster.
-      Clusters[DstIndex - 1].High = CaseVal;
-      Clusters[DstIndex - 1].Prob += CC.Prob;
-    } else {
-      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
-                   sizeof(Clusters[SrcIndex]));
-    }
-  }
-  Clusters.resize(DstIndex);
-}
-
 void SelectionDAGBuilder::UpdateSplitBlock(MachineBasicBlock *First,
                                            MachineBasicBlock *Last) {
   // Update JTCases.
-  for (unsigned i = 0, e = JTCases.size(); i != e; ++i)
-    if (JTCases[i].first.HeaderBB == First)
-      JTCases[i].first.HeaderBB = Last;
+  for (unsigned i = 0, e = SL->JTCases.size(); i != e; ++i)
+    if (SL->JTCases[i].first.HeaderBB == First)
+      SL->JTCases[i].first.HeaderBB = Last;
 
   // Update BitTestCases.
-  for (unsigned i = 0, e = BitTestCases.size(); i != e; ++i)
-    if (BitTestCases[i].Parent == First)
-      BitTestCases[i].Parent = Last;
+  for (unsigned i = 0, e = SL->BitTestCases.size(); i != e; ++i)
+    if (SL->BitTestCases[i].Parent == First)
+      SL->BitTestCases[i].Parent = Last;
 }
 
 void SelectionDAGBuilder::visitIndirectBr(const IndirectBrInst &I) {
@@ -9943,450 +9913,6 @@ void SelectionDAGBuilder::updateDAGForMa
     HasTailCall = true;
 }
 
-uint64_t
-SelectionDAGBuilder::getJumpTableRange(const CaseClusterVector &Clusters,
-                                       unsigned First, unsigned Last) const {
-  assert(Last >= First);
-  const APInt &LowCase = Clusters[First].Low->getValue();
-  const APInt &HighCase = Clusters[Last].High->getValue();
-  assert(LowCase.getBitWidth() == HighCase.getBitWidth());
-
-  // FIXME: A range of consecutive cases has 100% density, but only requires one
-  // comparison to lower. We should discriminate against such consecutive ranges
-  // in jump tables.
-
-  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
-}
-
-uint64_t SelectionDAGBuilder::getJumpTableNumCases(
-    const SmallVectorImpl<unsigned> &TotalCases, unsigned First,
-    unsigned Last) const {
-  assert(Last >= First);
-  assert(TotalCases[Last] >= TotalCases[First]);
-  uint64_t NumCases =
-      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
-  return NumCases;
-}
-
-bool SelectionDAGBuilder::buildJumpTable(const CaseClusterVector &Clusters,
-                                         unsigned First, unsigned Last,
-                                         const SwitchInst *SI,
-                                         MachineBasicBlock *DefaultMBB,
-                                         CaseCluster &JTCluster) {
-  assert(First <= Last);
-
-  auto Prob = BranchProbability::getZero();
-  unsigned NumCmps = 0;
-  std::vector<MachineBasicBlock*> Table;
-  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
-
-  // Initialize probabilities in JTProbs.
-  for (unsigned I = First; I <= Last; ++I)
-    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
-
-  for (unsigned I = First; I <= Last; ++I) {
-    assert(Clusters[I].Kind == CC_Range);
-    Prob += Clusters[I].Prob;
-    const APInt &Low = Clusters[I].Low->getValue();
-    const APInt &High = Clusters[I].High->getValue();
-    NumCmps += (Low == High) ? 1 : 2;
-    if (I != First) {
-      // Fill the gap between this and the previous cluster.
-      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
-      assert(PreviousHigh.slt(Low));
-      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
-      for (uint64_t J = 0; J < Gap; J++)
-        Table.push_back(DefaultMBB);
-    }
-    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
-    for (uint64_t J = 0; J < ClusterSize; ++J)
-      Table.push_back(Clusters[I].MBB);
-    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
-  }
-
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  unsigned NumDests = JTProbs.size();
-  if (TLI.isSuitableForBitTests(
-          NumDests, NumCmps, Clusters[First].Low->getValue(),
-          Clusters[Last].High->getValue(), DAG.getDataLayout())) {
-    // Clusters[First..Last] should be lowered as bit tests instead.
-    return false;
-  }
-
-  // Create the MBB that will load from and jump through the table.
-  // Note: We create it here, but it's not inserted into the function yet.
-  MachineFunction *CurMF = FuncInfo.MF;
-  MachineBasicBlock *JumpTableMBB =
-      CurMF->CreateMachineBasicBlock(SI->getParent());
-
-  // Add successors. Note: use table order for determinism.
-  SmallPtrSet<MachineBasicBlock *, 8> Done;
-  for (MachineBasicBlock *Succ : Table) {
-    if (Done.count(Succ))
-      continue;
-    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
-    Done.insert(Succ);
-  }
-  JumpTableMBB->normalizeSuccProbs();
-
-  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding())
-                     ->createJumpTableIndex(Table);
-
-  // Set up the jump table info.
-  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
-  JumpTableHeader JTH(Clusters[First].Low->getValue(),
-                      Clusters[Last].High->getValue(), SI->getCondition(),
-                      nullptr, false);
-  JTCases.emplace_back(std::move(JTH), std::move(JT));
-
-  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
-                                     JTCases.size() - 1, Prob);
-  return true;
-}
-
-void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters,
-                                         const SwitchInst *SI,
-                                         MachineBasicBlock *DefaultMBB) {
-#ifndef NDEBUG
-  // Clusters must be non-empty, sorted, and only contain Range clusters.
-  assert(!Clusters.empty());
-  for (CaseCluster &C : Clusters)
-    assert(C.Kind == CC_Range);
-  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
-    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
-#endif
-
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  if (!TLI.areJTsAllowed(SI->getParent()->getParent()))
-    return;
-
-  const int64_t N = Clusters.size();
-  const unsigned MinJumpTableEntries = TLI.getMinimumJumpTableEntries();
-  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
-
-  if (N < 2 || N < MinJumpTableEntries)
-    return;
-
-  // TotalCases[i]: Total nbr of cases in Clusters[0..i].
-  SmallVector<unsigned, 8> TotalCases(N);
-  for (unsigned i = 0; i < N; ++i) {
-    const APInt &Hi = Clusters[i].High->getValue();
-    const APInt &Lo = Clusters[i].Low->getValue();
-    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
-    if (i != 0)
-      TotalCases[i] += TotalCases[i - 1];
-  }
-
-  // Cheap case: the whole range may be suitable for jump table.
-  uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
-  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
-  assert(NumCases < UINT64_MAX / 100);
-  assert(Range >= NumCases);
-  if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
-    CaseCluster JTCluster;
-    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
-      Clusters[0] = JTCluster;
-      Clusters.resize(1);
-      return;
-    }
-  }
-
-  // The algorithm below is not suitable for -O0.
-  if (TM.getOptLevel() == CodeGenOpt::None)
-    return;
-
-  // Split Clusters into minimum number of dense partitions. The algorithm uses
-  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
-  // for the Case Statement'" (1994), but builds the MinPartitions array in
-  // reverse order to make it easier to reconstruct the partitions in ascending
-  // order. In the choice between two optimal partitionings, it picks the one
-  // which yields more jump tables.
-
-  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
-  SmallVector<unsigned, 8> MinPartitions(N);
-  // LastElement[i] is the last element of the partition starting at i.
-  SmallVector<unsigned, 8> LastElement(N);
-  // PartitionsScore[i] is used to break ties when choosing between two
-  // partitionings resulting in the same number of partitions.
-  SmallVector<unsigned, 8> PartitionsScore(N);
-  // For PartitionsScore, a small number of comparisons is considered as good as
-  // a jump table and a single comparison is considered better than a jump
-  // table.
-  enum PartitionScores : unsigned {
-    NoTable = 0,
-    Table = 1,
-    FewCases = 1,
-    SingleCase = 2
-  };
-
-  // Base case: There is only one way to partition Clusters[N-1].
-  MinPartitions[N - 1] = 1;
-  LastElement[N - 1] = N - 1;
-  PartitionsScore[N - 1] = PartitionScores::SingleCase;
-
-  // Note: loop indexes are signed to avoid underflow.
-  for (int64_t i = N - 2; i >= 0; i--) {
-    // Find optimal partitioning of Clusters[i..N-1].
-    // Baseline: Put Clusters[i] into a partition on its own.
-    MinPartitions[i] = MinPartitions[i + 1] + 1;
-    LastElement[i] = i;
-    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
-
-    // Search for a solution that results in fewer partitions.
-    for (int64_t j = N - 1; j > i; j--) {
-      // Try building a partition from Clusters[i..j].
-      uint64_t Range = getJumpTableRange(Clusters, i, j);
-      uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
-      assert(NumCases < UINT64_MAX / 100);
-      assert(Range >= NumCases);
-      if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
-        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
-        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
-        int64_t NumEntries = j - i + 1;
-
-        if (NumEntries == 1)
-          Score += PartitionScores::SingleCase;
-        else if (NumEntries <= SmallNumberOfEntries)
-          Score += PartitionScores::FewCases;
-        else if (NumEntries >= MinJumpTableEntries)
-          Score += PartitionScores::Table;
-
-        // If this leads to fewer partitions, or to the same number of
-        // partitions with better score, it is a better partitioning.
-        if (NumPartitions < MinPartitions[i] ||
-            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
-          MinPartitions[i] = NumPartitions;
-          LastElement[i] = j;
-          PartitionsScore[i] = Score;
-        }
-      }
-    }
-  }
-
-  // Iterate over the partitions, replacing some with jump tables in-place.
-  unsigned DstIndex = 0;
-  for (unsigned First = 0, Last; First < N; First = Last + 1) {
-    Last = LastElement[First];
-    assert(Last >= First);
-    assert(DstIndex <= First);
-    unsigned NumClusters = Last - First + 1;
-
-    CaseCluster JTCluster;
-    if (NumClusters >= MinJumpTableEntries &&
-        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
-      Clusters[DstIndex++] = JTCluster;
-    } else {
-      for (unsigned I = First; I <= Last; ++I)
-        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
-    }
-  }
-  Clusters.resize(DstIndex);
-}
-
-bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
-                                        unsigned First, unsigned Last,
-                                        const SwitchInst *SI,
-                                        CaseCluster &BTCluster) {
-  assert(First <= Last);
-  if (First == Last)
-    return false;
-
-  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
-  unsigned NumCmps = 0;
-  for (int64_t I = First; I <= Last; ++I) {
-    assert(Clusters[I].Kind == CC_Range);
-    Dests.set(Clusters[I].MBB->getNumber());
-    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
-  }
-  unsigned NumDests = Dests.count();
-
-  APInt Low = Clusters[First].Low->getValue();
-  APInt High = Clusters[Last].High->getValue();
-  assert(Low.slt(High));
-
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  const DataLayout &DL = DAG.getDataLayout();
-  if (!TLI.isSuitableForBitTests(NumDests, NumCmps, Low, High, DL))
-    return false;
-
-  APInt LowBound;
-  APInt CmpRange;
-
-  const int BitWidth = TLI.getPointerTy(DL).getSizeInBits();
-  assert(TLI.rangeFitsInWord(Low, High, DL) &&
-         "Case range must fit in bit mask!");
-
-  // Check if the clusters cover a contiguous range such that no value in the
-  // range will jump to the default statement.
-  bool ContiguousRange = true;
-  for (int64_t I = First + 1; I <= Last; ++I) {
-    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
-      ContiguousRange = false;
-      break;
-    }
-  }
-
-  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
-    // Optimize the case where all the case values fit in a word without having
-    // to subtract minValue. In this case, we can optimize away the subtraction.
-    LowBound = APInt::getNullValue(Low.getBitWidth());
-    CmpRange = High;
-    ContiguousRange = false;
-  } else {
-    LowBound = Low;
-    CmpRange = High - Low;
-  }
-
-  CaseBitsVector CBV;
-  auto TotalProb = BranchProbability::getZero();
-  for (unsigned i = First; i <= Last; ++i) {
-    // Find the CaseBits for this destination.
-    unsigned j;
-    for (j = 0; j < CBV.size(); ++j)
-      if (CBV[j].BB == Clusters[i].MBB)
-        break;
-    if (j == CBV.size())
-      CBV.push_back(
-          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
-    CaseBits *CB = &CBV[j];
-
-    // Update Mask, Bits and ExtraProb.
-    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
-    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
-    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
-    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
-    CB->Bits += Hi - Lo + 1;
-    CB->ExtraProb += Clusters[i].Prob;
-    TotalProb += Clusters[i].Prob;
-  }
-
-  BitTestInfo BTI;
-  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
-    // Sort by probability first, number of bits second, bit mask third.
-    if (a.ExtraProb != b.ExtraProb)
-      return a.ExtraProb > b.ExtraProb;
-    if (a.Bits != b.Bits)
-      return a.Bits > b.Bits;
-    return a.Mask < b.Mask;
-  });
-
-  for (auto &CB : CBV) {
-    MachineBasicBlock *BitTestBB =
-        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
-    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
-  }
-  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
-                            SI->getCondition(), -1U, MVT::Other, false,
-                            ContiguousRange, nullptr, nullptr, std::move(BTI),
-                            TotalProb);
-
-  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
-                                    BitTestCases.size() - 1, TotalProb);
-  return true;
-}
-
-void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters,
-                                              const SwitchInst *SI) {
-// Partition Clusters into as few subsets as possible, where each subset has a
-// range that fits in a machine word and has <= 3 unique destinations.
-
-#ifndef NDEBUG
-  // Clusters must be sorted and contain Range or JumpTable clusters.
-  assert(!Clusters.empty());
-  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
-  for (const CaseCluster &C : Clusters)
-    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
-  for (unsigned i = 1; i < Clusters.size(); ++i)
-    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
-#endif
-
-  // The algorithm below is not suitable for -O0.
-  if (TM.getOptLevel() == CodeGenOpt::None)
-    return;
-
-  // If target does not have legal shift left, do not emit bit tests at all.
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  const DataLayout &DL = DAG.getDataLayout();
-
-  EVT PTy = TLI.getPointerTy(DL);
-  if (!TLI.isOperationLegal(ISD::SHL, PTy))
-    return;
-
-  int BitWidth = PTy.getSizeInBits();
-  const int64_t N = Clusters.size();
-
-  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
-  SmallVector<unsigned, 8> MinPartitions(N);
-  // LastElement[i] is the last element of the partition starting at i.
-  SmallVector<unsigned, 8> LastElement(N);
-
-  // FIXME: This might not be the best algorithm for finding bit test clusters.
-
-  // Base case: There is only one way to partition Clusters[N-1].
-  MinPartitions[N - 1] = 1;
-  LastElement[N - 1] = N - 1;
-
-  // Note: loop indexes are signed to avoid underflow.
-  for (int64_t i = N - 2; i >= 0; --i) {
-    // Find optimal partitioning of Clusters[i..N-1].
-    // Baseline: Put Clusters[i] into a partition on its own.
-    MinPartitions[i] = MinPartitions[i + 1] + 1;
-    LastElement[i] = i;
-
-    // Search for a solution that results in fewer partitions.
-    // Note: the search is limited by BitWidth, reducing time complexity.
-    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
-      // Try building a partition from Clusters[i..j].
-
-      // Check the range.
-      if (!TLI.rangeFitsInWord(Clusters[i].Low->getValue(),
-                               Clusters[j].High->getValue(), DL))
-        continue;
-
-      // Check nbr of destinations and cluster types.
-      // FIXME: This works, but doesn't seem very efficient.
-      bool RangesOnly = true;
-      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
-      for (int64_t k = i; k <= j; k++) {
-        if (Clusters[k].Kind != CC_Range) {
-          RangesOnly = false;
-          break;
-        }
-        Dests.set(Clusters[k].MBB->getNumber());
-      }
-      if (!RangesOnly || Dests.count() > 3)
-        break;
-
-      // Check if it's a better partition.
-      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
-      if (NumPartitions < MinPartitions[i]) {
-        // Found a better partition.
-        MinPartitions[i] = NumPartitions;
-        LastElement[i] = j;
-      }
-    }
-  }
-
-  // Iterate over the partitions, replacing with bit-test clusters in-place.
-  unsigned DstIndex = 0;
-  for (unsigned First = 0, Last; First < N; First = Last + 1) {
-    Last = LastElement[First];
-    assert(First <= Last);
-    assert(DstIndex <= First);
-
-    CaseCluster BitTestCluster;
-    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
-      Clusters[DstIndex++] = BitTestCluster;
-    } else {
-      size_t NumClusters = Last - First + 1;
-      std::memmove(&Clusters[DstIndex], &Clusters[First],
-                   sizeof(Clusters[0]) * NumClusters);
-      DstIndex += NumClusters;
-    }
-  }
-  Clusters.resize(DstIndex);
-}
-
 void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
                                         MachineBasicBlock *SwitchMBB,
                                         MachineBasicBlock *DefaultMBB) {
@@ -10506,8 +10032,8 @@ void SelectionDAGBuilder::lowerWorkItem(
     switch (I->Kind) {
       case CC_JumpTable: {
         // FIXME: Optimize away range check based on pivot comparisons.
-        JumpTableHeader *JTH = &JTCases[I->JTCasesIndex].first;
-        JumpTable *JT = &JTCases[I->JTCasesIndex].second;
+        JumpTableHeader *JTH = &SL->JTCases[I->JTCasesIndex].first;
+        SwitchCG::JumpTable *JT = &SL->JTCases[I->JTCasesIndex].second;
 
         // The jump block hasn't been inserted yet; insert it here.
         MachineBasicBlock *JumpMBB = JT->MBB;
@@ -10557,7 +10083,7 @@ void SelectionDAGBuilder::lowerWorkItem(
         // FIXME: If Fallthrough is unreachable, skip the range check.
 
         // FIXME: Optimize away range check based on pivot comparisons.
-        BitTestBlock *BTB = &BitTestCases[I->BTCasesIndex];
+        BitTestBlock *BTB = &SL->BitTestCases[I->BTCasesIndex];
 
         // The bit test blocks haven't been inserted yet; insert them here.
         for (BitTestCase &BTC : BTB->Cases)
@@ -10611,7 +10137,7 @@ void SelectionDAGBuilder::lowerWorkItem(
         if (CurMBB == SwitchMBB)
           visitSwitchCase(CB, SwitchMBB);
         else
-          SwitchCases.push_back(CB);
+          SL->SwitchCases.push_back(CB);
 
         break;
       }
@@ -10762,7 +10288,7 @@ void SelectionDAGBuilder::splitWorkItem(
   if (W.MBB == SwitchMBB)
     visitSwitchCase(CB, SwitchMBB);
   else
-    SwitchCases.push_back(CB);
+    SL->SwitchCases.push_back(CB);
 }
 
 // Scale CaseProb after peeling a case with the probablity of PeeledCaseProb
@@ -10874,8 +10400,8 @@ void SelectionDAGBuilder::visitSwitch(co
     return;
   }
 
-  findJumpTables(Clusters, &SI, DefaultMBB);
-  findBitTestClusters(Clusters, &SI);
+  SL->findJumpTables(Clusters, &SI, DefaultMBB);
+  SL->findBitTestClusters(Clusters, &SI);
 
   LLVM_DEBUG({
     dbgs() << "Case clusters: ";

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h?rev=362857&r1=362856&r2=362857&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h Fri Jun  7 17:05:17 2019
@@ -23,6 +23,7 @@
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/SwitchLoweringUtils.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/CallSite.h"
@@ -146,239 +147,27 @@ private:
   /// create.
   unsigned SDNodeOrder;
 
-  enum CaseClusterKind {
-    /// A cluster of adjacent case labels with the same destination, or just one
-    /// case.
-    CC_Range,
-    /// A cluster of cases suitable for jump table lowering.
-    CC_JumpTable,
-    /// A cluster of cases suitable for bit test lowering.
-    CC_BitTests
-  };
-
-  /// A cluster of case labels.
-  struct CaseCluster {
-    CaseClusterKind Kind;
-    const ConstantInt *Low, *High;
-    union {
-      MachineBasicBlock *MBB;
-      unsigned JTCasesIndex;
-      unsigned BTCasesIndex;
-    };
-    BranchProbability Prob;
-
-    static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
-                             MachineBasicBlock *MBB, BranchProbability Prob) {
-      CaseCluster C;
-      C.Kind = CC_Range;
-      C.Low = Low;
-      C.High = High;
-      C.MBB = MBB;
-      C.Prob = Prob;
-      return C;
-    }
-
-    static CaseCluster jumpTable(const ConstantInt *Low,
-                                 const ConstantInt *High, unsigned JTCasesIndex,
-                                 BranchProbability Prob) {
-      CaseCluster C;
-      C.Kind = CC_JumpTable;
-      C.Low = Low;
-      C.High = High;
-      C.JTCasesIndex = JTCasesIndex;
-      C.Prob = Prob;
-      return C;
-    }
-
-    static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
-                                unsigned BTCasesIndex, BranchProbability Prob) {
-      CaseCluster C;
-      C.Kind = CC_BitTests;
-      C.Low = Low;
-      C.High = High;
-      C.BTCasesIndex = BTCasesIndex;
-      C.Prob = Prob;
-      return C;
-    }
-  };
-
-  using CaseClusterVector = std::vector<CaseCluster>;
-  using CaseClusterIt = CaseClusterVector::iterator;
-
-  struct CaseBits {
-    uint64_t Mask = 0;
-    MachineBasicBlock* BB = nullptr;
-    unsigned Bits = 0;
-    BranchProbability ExtraProb;
-
-    CaseBits() = default;
-    CaseBits(uint64_t mask, MachineBasicBlock* bb, unsigned bits,
-             BranchProbability Prob):
-      Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
-  };
-
-  using CaseBitsVector = std::vector<CaseBits>;
-
-  /// Sort Clusters and merge adjacent cases.
-  void sortAndRangeify(CaseClusterVector &Clusters);
-
-  /// This structure is used to communicate between SelectionDAGBuilder and
-  /// SDISel for the code generation of additional basic blocks needed by
-  /// multi-case switch statements.
-  struct CaseBlock {
-    // The condition code to use for the case block's setcc node.
-    // Besides the integer condition codes, this can also be SETTRUE, in which
-    // case no comparison gets emitted.
-    ISD::CondCode CC;
-
-    // The LHS/MHS/RHS of the comparison to emit.
-    // Emit by default LHS op RHS. MHS is used for range comparisons:
-    // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
-    const Value *CmpLHS, *CmpMHS, *CmpRHS;
-
-    // The block to branch to if the setcc is true/false.
-    MachineBasicBlock *TrueBB, *FalseBB;
-
-    // The block into which to emit the code for the setcc and branches.
-    MachineBasicBlock *ThisBB;
-
-    /// The debug location of the instruction this CaseBlock was
-    /// produced from.
-    SDLoc DL;
-
-    // Branch weights.
-    BranchProbability TrueProb, FalseProb;
-
-    CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
-              const Value *cmpmiddle, MachineBasicBlock *truebb,
-              MachineBasicBlock *falsebb, MachineBasicBlock *me,
-              SDLoc dl,
-              BranchProbability trueprob = BranchProbability::getUnknown(),
-              BranchProbability falseprob = BranchProbability::getUnknown())
-        : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
-          TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
-          TrueProb(trueprob), FalseProb(falseprob) {}
-  };
-
-  struct JumpTable {
-    /// The virtual register containing the index of the jump table entry
-    /// to jump to.
-    unsigned Reg;
-    /// The JumpTableIndex for this jump table in the function.
-    unsigned JTI;
-    /// The MBB into which to emit the code for the indirect jump.
-    MachineBasicBlock *MBB;
-    /// The MBB of the default bb, which is a successor of the range
-    /// check MBB.  This is when updating PHI nodes in successors.
-    MachineBasicBlock *Default;
-
-    JumpTable(unsigned R, unsigned J, MachineBasicBlock *M,
-              MachineBasicBlock *D): Reg(R), JTI(J), MBB(M), Default(D) {}
-  };
-  struct JumpTableHeader {
-    APInt First;
-    APInt Last;
-    const Value *SValue;
-    MachineBasicBlock *HeaderBB;
-    bool Emitted;
-    bool OmitRangeCheck;
-
-    JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
-                    bool E = false)
-        : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
-          Emitted(E), OmitRangeCheck(false) {}
-  };
-  using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
-
-  struct BitTestCase {
-    uint64_t Mask;
-    MachineBasicBlock *ThisBB;
-    MachineBasicBlock *TargetBB;
-    BranchProbability ExtraProb;
-
-    BitTestCase(uint64_t M, MachineBasicBlock* T, MachineBasicBlock* Tr,
-                BranchProbability Prob):
-      Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
-  };
-
-  using BitTestInfo = SmallVector<BitTestCase, 3>;
-
-  struct BitTestBlock {
-    APInt First;
-    APInt Range;
-    const Value *SValue;
-    unsigned Reg;
-    MVT RegVT;
-    bool Emitted;
-    bool ContiguousRange;
-    MachineBasicBlock *Parent;
-    MachineBasicBlock *Default;
-    BitTestInfo Cases;
-    BranchProbability Prob;
-    BranchProbability DefaultProb;
-
-    BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT,
-                 bool E, bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
-                 BitTestInfo C, BranchProbability Pr)
-        : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
-          RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
-          Cases(std::move(C)), Prob(Pr) {}
-  };
-
-  /// Return the range of value in [First..Last].
-  uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
-                             unsigned Last) const;
-
-  /// Return the number of cases in [First..Last].
-  uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
-                                unsigned First, unsigned Last) const;
-
-  /// Build a jump table cluster from Clusters[First..Last]. Returns false if it
-  /// decides it's not a good idea.
-  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
-                      unsigned Last, const SwitchInst *SI,
-                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
-
-  /// Find clusters of cases suitable for jump table lowering.
-  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
-                      MachineBasicBlock *DefaultMBB);
-
-  /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
-  /// decides it's not a good idea.
-  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
-                     const SwitchInst *SI, CaseCluster &BTCluster);
-
-  /// Find clusters of cases suitable for bit test lowering.
-  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
-
-  struct SwitchWorkListItem {
-    MachineBasicBlock *MBB;
-    CaseClusterIt FirstCluster;
-    CaseClusterIt LastCluster;
-    const ConstantInt *GE;
-    const ConstantInt *LT;
-    BranchProbability DefaultProb;
-  };
-  using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
-
   /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
   /// than each cluster in the range, its rank is 0.
-  static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
-                                  CaseClusterIt Last);
+  unsigned caseClusterRank(const SwitchCG::CaseCluster &CC,
+                           SwitchCG::CaseClusterIt First,
+                           SwitchCG::CaseClusterIt Last);
 
   /// Emit comparison and split W into two subtrees.
-  void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W,
-                     Value *Cond, MachineBasicBlock *SwitchMBB);
+  void splitWorkItem(SwitchCG::SwitchWorkList &WorkList,
+                     const SwitchCG::SwitchWorkListItem &W, Value *Cond,
+                     MachineBasicBlock *SwitchMBB);
 
   /// Lower W.
-  void lowerWorkItem(SwitchWorkListItem W, Value *Cond,
+  void lowerWorkItem(SwitchCG::SwitchWorkListItem W, Value *Cond,
                      MachineBasicBlock *SwitchMBB,
                      MachineBasicBlock *DefaultMBB);
 
   /// Peel the top probability case if it exceeds the threshold
-  MachineBasicBlock *peelDominantCaseCluster(const SwitchInst &SI,
-                                             CaseClusterVector &Clusters,
-                                             BranchProbability &PeeledCaseProb);
+  MachineBasicBlock *
+  peelDominantCaseCluster(const SwitchInst &SI,
+                          SwitchCG::CaseClusterVector &Clusters,
+                          BranchProbability &PeeledCaseProb);
 
   /// A class which encapsulates all of the information needed to generate a
   /// stack protector check and signals to isel via its state being initialized
@@ -591,17 +380,22 @@ public:
   AliasAnalysis *AA = nullptr;
   const TargetLibraryInfo *LibInfo;
 
-  /// Vector of CaseBlock structures used to communicate SwitchInst code
-  /// generation information.
-  std::vector<CaseBlock> SwitchCases;
-
-  /// Vector of JumpTable structures used to communicate SwitchInst code
-  /// generation information.
-  std::vector<JumpTableBlock> JTCases;
-
-  /// Vector of BitTestBlock structures used to communicate SwitchInst code
-  /// generation information.
-  std::vector<BitTestBlock> BitTestCases;
+  /// SelectionDAG flavour of SwitchCG::SwitchLowering. The shared switch
+  /// lowering code adds CFG edges through addSuccessorWithProb; this
+  /// implementation forwards those requests to the owning SelectionDAGBuilder.
+  class SDAGSwitchLowering : public SwitchCG::SwitchLowering {
+  public:
+    SDAGSwitchLowering(SelectionDAGBuilder *sdb, FunctionLoweringInfo &funcinfo)
+        : SwitchCG::SwitchLowering(funcinfo), SDB(sdb) {}
+
+    /// Make Dst a successor of Src with probability Prob by delegating to
+    /// SelectionDAGBuilder::addSuccessorWithProb.
+    void addSuccessorWithProb(
+        MachineBasicBlock *Src, MachineBasicBlock *Dst,
+        BranchProbability Prob = BranchProbability::getUnknown()) override {
+      SDB->addSuccessorWithProb(Src, Dst, Prob);
+    }
+
+  private:
+    /// Back-pointer to the builder that owns this object; not owned here.
+    SelectionDAGBuilder *SDB;
+  };
+
+  std::unique_ptr<SDAGSwitchLowering> SL;
 
   /// A StackProtectorDescriptor structure used to communicate stack protector
   /// information in between SelectBasicBlock and FinishBasicBlock.
@@ -632,7 +426,8 @@ public:
   SelectionDAGBuilder(SelectionDAG &dag, FunctionLoweringInfo &funcinfo,
                       SwiftErrorValueTracking &swifterror, CodeGenOpt::Level ol)
       : SDNodeOrder(LowestSDNodeOrder), TM(dag.getTarget()), DAG(dag),
-        FuncInfo(funcinfo), SwiftError(swifterror) {}
+        SL(make_unique<SDAGSwitchLowering>(this, funcinfo)), FuncInfo(funcinfo),
+        SwiftError(swifterror) {}
 
   void init(GCFunctionInfo *gfi, AliasAnalysis *AA,
             const TargetLibraryInfo *li);
@@ -738,7 +533,7 @@ public:
                                     MachineBasicBlock *SwitchBB,
                                     BranchProbability TProb, BranchProbability FProb,
                                     bool InvertCond);
-  bool ShouldEmitAsBranches(const std::vector<CaseBlock> &Cases);
+  bool ShouldEmitAsBranches(const std::vector<SwitchCG::CaseBlock> &Cases);
   bool isExportableFromCurrentBlock(const Value *V, const BasicBlock *FromBB);
   void CopyToExportRegsIfNeeded(const Value *V);
   void ExportFromCurrentBlock(const Value *V);
@@ -851,20 +646,18 @@ private:
       BranchProbability Prob = BranchProbability::getUnknown());
 
 public:
-  void visitSwitchCase(CaseBlock &CB,
-                       MachineBasicBlock *SwitchBB);
+  void visitSwitchCase(SwitchCG::CaseBlock &CB, MachineBasicBlock *SwitchBB);
   void visitSPDescriptorParent(StackProtectorDescriptor &SPD,
                                MachineBasicBlock *ParentBB);
   void visitSPDescriptorFailure(StackProtectorDescriptor &SPD);
-  void visitBitTestHeader(BitTestBlock &B, MachineBasicBlock *SwitchBB);
-  void visitBitTestCase(BitTestBlock &BB,
-                        MachineBasicBlock* NextMBB,
-                        BranchProbability BranchProbToNext,
-                        unsigned Reg,
-                        BitTestCase &B,
-                        MachineBasicBlock *SwitchBB);
-  void visitJumpTable(JumpTable &JT);
-  void visitJumpTableHeader(JumpTable &JT, JumpTableHeader &JTH,
+  void visitBitTestHeader(SwitchCG::BitTestBlock &B,
+                          MachineBasicBlock *SwitchBB);
+  void visitBitTestCase(SwitchCG::BitTestBlock &BB, MachineBasicBlock *NextMBB,
+                        BranchProbability BranchProbToNext, unsigned Reg,
+                        SwitchCG::BitTestCase &B, MachineBasicBlock *SwitchBB);
+  void visitJumpTable(SwitchCG::JumpTable &JT);
+  void visitJumpTableHeader(SwitchCG::JumpTable &JT,
+                            SwitchCG::JumpTableHeader &JTH,
                             MachineBasicBlock *SwitchBB);
 
 private:

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp?rev=362857&r1=362856&r2=362857&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp Fri Jun  7 17:05:17 2019
@@ -1740,7 +1740,7 @@ SelectionDAGISel::FinishBasicBlock() {
   }
 
   // Lower each BitTestBlock.
-  for (auto &BTB : SDB->BitTestCases) {
+  for (auto &BTB : SDB->SL->BitTestCases) {
     // Lower header first, if it wasn't already lowered
     if (!BTB.Emitted) {
       // Set the current basic block to the mbb we wish to insert the code into
@@ -1821,30 +1821,30 @@ SelectionDAGISel::FinishBasicBlock() {
       }
     }
   }
-  SDB->BitTestCases.clear();
+  SDB->SL->BitTestCases.clear();
 
   // If the JumpTable record is filled in, then we need to emit a jump table.
   // Updating the PHI nodes is tricky in this case, since we need to determine
   // whether the PHI is a successor of the range check MBB or the jump table MBB
-  for (unsigned i = 0, e = SDB->JTCases.size(); i != e; ++i) {
+  for (unsigned i = 0, e = SDB->SL->JTCases.size(); i != e; ++i) {
     // Lower header first, if it wasn't already lowered
-    if (!SDB->JTCases[i].first.Emitted) {
+    if (!SDB->SL->JTCases[i].first.Emitted) {
       // Set the current basic block to the mbb we wish to insert the code into
-      FuncInfo->MBB = SDB->JTCases[i].first.HeaderBB;
+      FuncInfo->MBB = SDB->SL->JTCases[i].first.HeaderBB;
       FuncInfo->InsertPt = FuncInfo->MBB->end();
       // Emit the code
-      SDB->visitJumpTableHeader(SDB->JTCases[i].second, SDB->JTCases[i].first,
-                                FuncInfo->MBB);
+      SDB->visitJumpTableHeader(SDB->SL->JTCases[i].second,
+                                SDB->SL->JTCases[i].first, FuncInfo->MBB);
       CurDAG->setRoot(SDB->getRoot());
       SDB->clear();
       CodeGenAndEmitDAG();
     }
 
     // Set the current basic block to the mbb we wish to insert the code into
-    FuncInfo->MBB = SDB->JTCases[i].second.MBB;
+    FuncInfo->MBB = SDB->SL->JTCases[i].second.MBB;
     FuncInfo->InsertPt = FuncInfo->MBB->end();
     // Emit the code
-    SDB->visitJumpTable(SDB->JTCases[i].second);
+    SDB->visitJumpTable(SDB->SL->JTCases[i].second);
     CurDAG->setRoot(SDB->getRoot());
     SDB->clear();
     CodeGenAndEmitDAG();
@@ -1857,31 +1857,31 @@ SelectionDAGISel::FinishBasicBlock() {
       assert(PHI->isPHI() &&
              "This is not a machine PHI node that we are updating!");
       // "default" BB. We can go there only from header BB.
-      if (PHIBB == SDB->JTCases[i].second.Default)
+      if (PHIBB == SDB->SL->JTCases[i].second.Default)
         PHI.addReg(FuncInfo->PHINodesToUpdate[pi].second)
-           .addMBB(SDB->JTCases[i].first.HeaderBB);
+           .addMBB(SDB->SL->JTCases[i].first.HeaderBB);
       // JT BB. Just iterate over successors here
       if (FuncInfo->MBB->isSuccessor(PHIBB))
         PHI.addReg(FuncInfo->PHINodesToUpdate[pi].second).addMBB(FuncInfo->MBB);
     }
   }
-  SDB->JTCases.clear();
+  SDB->SL->JTCases.clear();
 
   // If we generated any switch lowering information, build and codegen any
   // additional DAGs necessary.
-  for (unsigned i = 0, e = SDB->SwitchCases.size(); i != e; ++i) {
+  for (unsigned i = 0, e = SDB->SL->SwitchCases.size(); i != e; ++i) {
     // Set the current basic block to the mbb we wish to insert the code into
-    FuncInfo->MBB = SDB->SwitchCases[i].ThisBB;
+    FuncInfo->MBB = SDB->SL->SwitchCases[i].ThisBB;
     FuncInfo->InsertPt = FuncInfo->MBB->end();
 
     // Determine the unique successors.
     SmallVector<MachineBasicBlock *, 2> Succs;
-    Succs.push_back(SDB->SwitchCases[i].TrueBB);
-    if (SDB->SwitchCases[i].TrueBB != SDB->SwitchCases[i].FalseBB)
-      Succs.push_back(SDB->SwitchCases[i].FalseBB);
+    Succs.push_back(SDB->SL->SwitchCases[i].TrueBB);
+    if (SDB->SL->SwitchCases[i].TrueBB != SDB->SL->SwitchCases[i].FalseBB)
+      Succs.push_back(SDB->SL->SwitchCases[i].FalseBB);
 
     // Emit the code. Note that this could result in FuncInfo->MBB being split.
-    SDB->visitSwitchCase(SDB->SwitchCases[i], FuncInfo->MBB);
+    SDB->visitSwitchCase(SDB->SL->SwitchCases[i], FuncInfo->MBB);
     CurDAG->setRoot(SDB->getRoot());
     SDB->clear();
     CodeGenAndEmitDAG();
@@ -1917,7 +1917,7 @@ SelectionDAGISel::FinishBasicBlock() {
       }
     }
   }
-  SDB->SwitchCases.clear();
+  SDB->SL->SwitchCases.clear();
 }
 
 /// Create the scheduler. If a specific scheduler was specified

Added: llvm/trunk/lib/CodeGen/SwitchLoweringUtils.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SwitchLoweringUtils.cpp?rev=362857&view=auto
==============================================================================
--- llvm/trunk/lib/CodeGen/SwitchLoweringUtils.cpp (added)
+++ llvm/trunk/lib/CodeGen/SwitchLoweringUtils.cpp Fri Jun  7 17:05:17 2019
@@ -0,0 +1,486 @@
+//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains switch instruction lowering optimizations and utilities
+// for codegen, so that they can be shared by both SelectionDAG and GlobalISel.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/SwitchLoweringUtils.h"
+
+using namespace llvm;
+using namespace SwitchCG;
+
+// Return the number of values spanned by Clusters[First..Last], i.e. the high
+// bound of the last cluster minus the low bound of the first one, plus one.
+// The difference is capped at (UINT64_MAX - 1) / 100 before adding one;
+// presumably this keeps callers that scale the result by 100 (density checks)
+// from overflowing -- confirm against TargetLowering::isSuitableForJumpTable.
+uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
+                                     unsigned First, unsigned Last) {
+  assert(Last >= First);
+  const APInt &LowCase = Clusters[First].Low->getValue();
+  const APInt &HighCase = Clusters[Last].High->getValue();
+  assert(LowCase.getBitWidth() == HighCase.getBitWidth());
+
+  // FIXME: A range of consecutive cases has 100% density, but only requires one
+  // comparison to lower. We should discriminate against such consecutive ranges
+  // in jump tables.
+
+  const uint64_t SpanLimit = (UINT64_MAX - 1) / 100;
+  const uint64_t Span = (HighCase - LowCase).getLimitedValue(SpanLimit);
+  return Span + 1;
+}
+
+// Return how many case values Clusters[First..Last] cover. TotalCases is
+// expected to hold running totals (entry i counts clusters 0..i), so the
+// answer is the total up to Last minus the total of everything before First.
+uint64_t
+SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+                               unsigned First, unsigned Last) {
+  assert(Last >= First);
+  assert(TotalCases[Last] >= TotalCases[First]);
+  unsigned CasesBeforeFirst = 0;
+  if (First != 0)
+    CasesBeforeFirst = TotalCases[First - 1];
+  return TotalCases[Last] - CasesBeforeFirst;
+}
+
+// Find runs of clusters that are dense enough to be lowered as jump tables,
+// and replace each such run in place with a single CC_JumpTable cluster built
+// by buildJumpTable. Clusters that do not end up in a jump table are kept in
+// their original order.
+//
+// Preconditions (asserted in !NDEBUG builds): Clusters is non-empty, sorted in
+// strictly increasing order, and contains only CC_Range clusters.
+void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
+                                              const SwitchInst *SI,
+                                              MachineBasicBlock *DefaultMBB) {
+#ifndef NDEBUG
+  // Clusters must be non-empty, sorted, and only contain Range clusters.
+  assert(!Clusters.empty());
+  for (CaseCluster &C : Clusters)
+    assert(C.Kind == CC_Range);
+  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
+    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
+    return;
+
+  const int64_t N = Clusters.size();
+  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
+  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
+
+  // Need at least two clusters, and at least the target's minimum number of
+  // jump table entries, before any table is considered.
+  if (N < 2 || N < MinJumpTableEntries)
+    return;
+
+  // TotalCases[i]: Total nbr of cases in Clusters[0..i]. This running sum is
+  // what getJumpTableNumCases consumes.
+  SmallVector<unsigned, 8> TotalCases(N);
+  for (int64_t i = 0; i < N; ++i) {
+    const APInt &Hi = Clusters[i].High->getValue();
+    const APInt &Lo = Clusters[i].Low->getValue();
+    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
+    if (i != 0)
+      TotalCases[i] += TotalCases[i - 1];
+  }
+
+  // Cheap case: the whole range may be suitable for jump table.
+  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
+  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
+  assert(NumCases < UINT64_MAX / 100);
+  assert(Range >= NumCases);
+  if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+    CaseCluster JTCluster;
+    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
+      Clusters[0] = JTCluster;
+      Clusters.resize(1);
+      return;
+    }
+  }
+
+  // The algorithm below is not suitable for -O0.
+  if (TM->getOptLevel() == CodeGenOpt::None)
+    return;
+
+  // Split Clusters into minimum number of dense partitions. The algorithm uses
+  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
+  // for the Case Statement'" (1994), but builds the MinPartitions array in
+  // reverse order to make it easier to reconstruct the partitions in ascending
+  // order. In the choice between two optimal partitionings, it picks the one
+  // which yields more jump tables.
+
+  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+  SmallVector<unsigned, 8> MinPartitions(N);
+  // LastElement[i] is the last element of the partition starting at i.
+  SmallVector<unsigned, 8> LastElement(N);
+  // PartitionsScore[i] is used to break ties when choosing between two
+  // partitionings resulting in the same number of partitions.
+  SmallVector<unsigned, 8> PartitionsScore(N);
+  // For PartitionsScore, a small number of comparisons is considered as good as
+  // a jump table and a single comparison is considered better than a jump
+  // table.
+  enum PartitionScores : unsigned {
+    NoTable = 0,
+    Table = 1,
+    FewCases = 1,
+    SingleCase = 2
+  };
+
+  // Base case: There is only one way to partition Clusters[N-1].
+  MinPartitions[N - 1] = 1;
+  LastElement[N - 1] = N - 1;
+  PartitionsScore[N - 1] = PartitionScores::SingleCase;
+
+  // Note: loop indexes are signed to avoid underflow.
+  for (int64_t i = N - 2; i >= 0; i--) {
+    // Find optimal partitioning of Clusters[i..N-1].
+    // Baseline: Put Clusters[i] into a partition on its own.
+    MinPartitions[i] = MinPartitions[i + 1] + 1;
+    LastElement[i] = i;
+    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
+
+    // Search for a solution that results in fewer partitions.
+    for (int64_t j = N - 1; j > i; j--) {
+      // Try building a partition from Clusters[i..j].
+      uint64_t Range = getJumpTableRange(Clusters, i, j);
+      uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
+      assert(NumCases < UINT64_MAX / 100);
+      assert(Range >= NumCases);
+      if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
+        int64_t NumEntries = j - i + 1;
+
+        if (NumEntries == 1)
+          Score += PartitionScores::SingleCase;
+        else if (NumEntries <= SmallNumberOfEntries)
+          Score += PartitionScores::FewCases;
+        else if (NumEntries >= MinJumpTableEntries)
+          Score += PartitionScores::Table;
+
+        // If this leads to fewer partitions, or to the same number of
+        // partitions with better score, it is a better partitioning.
+        if (NumPartitions < MinPartitions[i] ||
+            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
+          MinPartitions[i] = NumPartitions;
+          LastElement[i] = j;
+          PartitionsScore[i] = Score;
+        }
+      }
+    }
+  }
+
+  // Iterate over the partitions, replacing some with jump tables in-place.
+  unsigned DstIndex = 0;
+  for (unsigned First = 0, Last; First < N; First = Last + 1) {
+    Last = LastElement[First];
+    assert(Last >= First);
+    assert(DstIndex <= First);
+    unsigned NumClusters = Last - First + 1;
+
+    CaseCluster JTCluster;
+    if (NumClusters >= MinJumpTableEntries &&
+        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
+      Clusters[DstIndex++] = JTCluster;
+    } else {
+      // Keep this partition's clusters, shifting them down over the slots
+      // freed by earlier jump tables. DstIndex <= I throughout, so copying
+      // front to back never reads a slot that was already overwritten.
+      for (unsigned I = First; I <= Last; ++I)
+        Clusters[DstIndex++] = Clusters[I];
+    }
+  }
+  Clusters.resize(DstIndex);
+}
+
+// Try to turn Clusters[First..Last] (all CC_Range) into a jump table.
+//
+// On success: creates the (not yet inserted) jump-table MBB and adds its
+// successors. It then records the table in JTCases, sets JTCluster to a
+// CC_JumpTable cluster referring to that entry, and returns true.
+//
+// Returns false without creating anything when the target reports that these
+// clusters are better lowered as bit tests. Gaps between consecutive clusters
+// are filled with DefaultMBB in the table.
+bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
+                                              unsigned First, unsigned Last,
+                                              const SwitchInst *SI,
+                                              MachineBasicBlock *DefaultMBB,
+                                              CaseCluster &JTCluster) {
+  assert(First <= Last);
+
+  auto Prob = BranchProbability::getZero();
+  unsigned NumCmps = 0;
+  std::vector<MachineBasicBlock*> Table;
+  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
+
+  // Initialize probabilities in JTProbs.
+  for (unsigned I = First; I <= Last; ++I)
+    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
+
+  for (unsigned I = First; I <= Last; ++I) {
+    assert(Clusters[I].Kind == CC_Range);
+    Prob += Clusters[I].Prob;
+    const APInt &Low = Clusters[I].Low->getValue();
+    const APInt &High = Clusters[I].High->getValue();
+    // Compare-count estimate: one for a single value, two for a range.
+    NumCmps += (Low == High) ? 1 : 2;
+    if (I != First) {
+      // Fill the gap between this and the previous cluster.
+      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
+      assert(PreviousHigh.slt(Low));
+      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
+      for (uint64_t J = 0; J < Gap; J++)
+        Table.push_back(DefaultMBB);
+    }
+    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
+    for (uint64_t J = 0; J < ClusterSize; ++J)
+      Table.push_back(Clusters[I].MBB);
+    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
+  }
+
+  unsigned NumDests = JTProbs.size();
+  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
+                                 Clusters[First].Low->getValue(),
+                                 Clusters[Last].High->getValue(), *DL)) {
+    // Clusters[First..Last] should be lowered as bit tests instead.
+    return false;
+  }
+
+  // Create the MBB that will load from and jump through the table.
+  // Note: We create it here, but it's not inserted into the function yet.
+  MachineFunction *CurMF = FuncInfo.MF;
+  MachineBasicBlock *JumpTableMBB =
+      CurMF->CreateMachineBasicBlock(SI->getParent());
+
+  // Add successors. Note: use table order for determinism. Each distinct
+  // block is added once; insert() reports a repeat via .second == false.
+  SmallPtrSet<MachineBasicBlock *, 8> Done;
+  for (MachineBasicBlock *Succ : Table) {
+    if (!Done.insert(Succ).second)
+      continue;
+    // NOTE(review): if DefaultMBB entered Table only as a gap filler it has no
+    // JTProbs entry, so operator[] inserts a default-constructed
+    // BranchProbability (presumably "unknown") -- confirm this is intended.
+    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
+  }
+  JumpTableMBB->normalizeSuccProbs();
+
+  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
+                     ->createJumpTableIndex(Table);
+
+  // Set up the jump table info. The register (-1U), default block and header
+  // block are placeholders here; presumably they are filled in when the
+  // cluster is lowered -- confirm in the lowering code.
+  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
+  JumpTableHeader JTH(Clusters[First].Low->getValue(),
+                      Clusters[Last].High->getValue(), SI->getCondition(),
+                      nullptr, false);
+  JTCases.emplace_back(std::move(JTH), std::move(JT));
+
+  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
+                                     JTCases.size() - 1, Prob);
+  return true;
+}
+
+void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
+                                                   const SwitchInst *SI) {
+  // Partition Clusters into as few subsets as possible, where each subset has a
+  // range that fits in a machine word and has <= 3 unique destinations. Each
+  // chosen subset is then offered to buildBitTests(); if it declines, the
+  // subset's clusters are kept unchanged.
+
+#ifndef NDEBUG
+  // Clusters must be sorted and contain Range or JumpTable clusters.
+  assert(!Clusters.empty());
+  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
+  for (const CaseCluster &C : Clusters)
+    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
+  for (unsigned i = 1; i < Clusters.size(); ++i)
+    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+  // The algorithm below is not suitable for -O0.
+  if (TM->getOptLevel() == CodeGenOpt::None)
+    return;
+
+  // If target does not have legal shift left, do not emit bit tests at all.
+  EVT PTy = TLI->getPointerTy(*DL);
+  if (!TLI->isOperationLegal(ISD::SHL, PTy))
+    return;
+
+  int BitWidth = PTy.getSizeInBits();
+  const int64_t N = Clusters.size();
+
+  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+  SmallVector<unsigned, 8> MinPartitions(N);
+  // LastElement[i] is the last element of the partition starting at i.
+  SmallVector<unsigned, 8> LastElement(N);
+
+  // FIXME: This might not be the best algorithm for finding bit test clusters.
+
+  // Base case: There is only one way to partition Clusters[N-1].
+  MinPartitions[N - 1] = 1;
+  LastElement[N - 1] = N - 1;
+
+  // Note: loop indexes are signed to avoid underflow.
+  for (int64_t i = N - 2; i >= 0; --i) {
+    // Find optimal partitioning of Clusters[i..N-1].
+    // Baseline: Put Clusters[i] into a partition on its own.
+    MinPartitions[i] = MinPartitions[i + 1] + 1;
+    LastElement[i] = i;
+
+    // A candidate partition Clusters[i..j] may only contain CC_Range clusters
+    // and at most 3 unique destinations. Once Clusters[i..j] violates either
+    // condition, every longer candidate does too, but a shorter one may still
+    // qualify. So find the largest qualifying end with a single forward scan
+    // (bounded by BitWidth). Giving up at the first long candidate that fails
+    // would wrongly skip shorter valid partitions. Rebuilding the destination
+    // set for every candidate would also be needlessly expensive.
+    const int64_t SearchEnd = std::min(N - 1, i + BitWidth - 1);
+    int64_t MaxEnd = i - 1;
+    SmallPtrSet<const MachineBasicBlock *, 4> Dests;
+    for (int64_t k = i; k <= SearchEnd; ++k) {
+      if (Clusters[k].Kind != CC_Range)
+        break;
+      Dests.insert(Clusters[k].MBB);
+      if (Dests.size() > 3)
+        break;
+      MaxEnd = k;
+    }
+
+    // Search for a solution that results in fewer partitions. Candidates are
+    // visited from longest to shortest with a strict comparison, so among
+    // equally good multi-cluster partitions the longest one is kept.
+    for (int64_t j = MaxEnd; j > i; --j) {
+      // Try building a partition from Clusters[i..j].
+
+      // Check the range. A shorter candidate may still fit, so keep looking.
+      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
+                                Clusters[j].High->getValue(), *DL))
+        continue;
+
+      // Check if it's a better partition.
+      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+      if (NumPartitions < MinPartitions[i]) {
+        // Found a better partition.
+        MinPartitions[i] = NumPartitions;
+        LastElement[i] = j;
+      }
+    }
+  }
+
+  // Iterate over the partitions, replacing with bit-test clusters in-place.
+  unsigned DstIndex = 0;
+  for (unsigned First = 0, Last; First < N; First = Last + 1) {
+    Last = LastElement[First];
+    assert(First <= Last);
+    assert(DstIndex <= First);
+
+    CaseCluster BitTestCluster;
+    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
+      Clusters[DstIndex++] = BitTestCluster;
+    } else {
+      // Source and destination may overlap (DstIndex <= First), hence memmove.
+      size_t NumClusters = Last - First + 1;
+      std::memmove(&Clusters[DstIndex], &Clusters[First],
+                   sizeof(Clusters[0]) * NumClusters);
+      DstIndex += NumClusters;
+    }
+  }
+  Clusters.resize(DstIndex);
+}
+
+bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
+                                             unsigned First, unsigned Last,
+                                             const SwitchInst *SI,
+                                             CaseCluster &BTCluster) {
+  // Try to lower the range clusters Clusters[First..Last] as one group of bit
+  // tests. On success a new entry is appended to BitTestCases, BTCluster is set
+  // to a bit-test cluster referring to it, and true is returned. Returns false,
+  // leaving BitTestCases untouched, for a single-cluster group or when
+  // TLI->isSuitableForBitTests() rejects the group.
+  assert(First <= Last);
+  if (First == Last)
+    return false;
+
+  // Count the distinct destinations and the comparison count handed to
+  // isSuitableForBitTests (1 per single-value cluster, 2 per wider range).
+  BitVector SeenDests(FuncInfo.MF->getNumBlockIDs());
+  unsigned NumCmps = 0;
+  for (int64_t Idx = First; Idx <= Last; ++Idx) {
+    const CaseCluster &C = Clusters[Idx];
+    assert(C.Kind == CC_Range);
+    SeenDests.set(C.MBB->getNumber());
+    NumCmps += (C.Low == C.High) ? 1 : 2;
+  }
+  unsigned NumDests = SeenDests.count();
+
+  APInt Low = Clusters[First].Low->getValue();
+  APInt High = Clusters[Last].High->getValue();
+  assert(Low.slt(High));
+
+  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
+    return false;
+
+  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
+  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
+         "Case range must fit in bit mask!");
+
+  // If every case value is positive and below the word size, the values can be
+  // used as bit positions directly, with no subtraction of the minimum. The
+  // tested interval then starts at 0 instead of Low. It therefore contains
+  // values below Low that no cluster covers, so such a group is never reported
+  // as contiguous.
+  const bool SkipSubtraction = Low.isStrictlyPositive() && High.slt(BitWidth);
+  APInt LowBound =
+      SkipSubtraction ? APInt::getNullValue(Low.getBitWidth()) : Low;
+  APInt CmpRange = SkipSubtraction ? High : High - Low;
+
+  // Otherwise, the group is contiguous when the clusters leave no holes, i.e.
+  // no value in the range will jump to the default statement.
+  bool ContiguousRange = !SkipSubtraction;
+  for (int64_t Idx = First + 1; ContiguousRange && Idx <= Last; ++Idx)
+    ContiguousRange = Clusters[Idx].Low->getValue() ==
+                      Clusters[Idx - 1].High->getValue() + 1;
+
+  // Build one CaseBits entry per destination. Each entry records the mask of
+  // bit positions (relative to LowBound) that branch there, how many positions
+  // that is, and the summed probability of its clusters.
+  CaseBitsVector CBV;
+  auto TotalProb = BranchProbability::getZero();
+  for (unsigned Idx = First; Idx <= Last; ++Idx) {
+    const CaseCluster &C = Clusters[Idx];
+
+    CaseBits *Entry = nullptr;
+    for (CaseBits &Existing : CBV) {
+      if (Existing.BB == C.MBB) {
+        Entry = &Existing;
+        break;
+      }
+    }
+    if (!Entry) {
+      CBV.push_back(CaseBits(0, C.MBB, 0, BranchProbability::getZero()));
+      Entry = &CBV.back();
+    }
+
+    uint64_t Lo = (C.Low->getValue() - LowBound).getZExtValue();
+    uint64_t Hi = (C.High->getValue() - LowBound).getZExtValue();
+    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
+    // Set bits Lo..Hi (inclusive).
+    Entry->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
+    Entry->Bits += Hi - Lo + 1;
+    Entry->ExtraProb += C.Prob;
+    TotalProb += C.Prob;
+  }
+
+  // Order destinations by probability, then by bit count (both descending),
+  // and finally by mask (ascending).
+  llvm::sort(CBV, [](const CaseBits &LHS, const CaseBits &RHS) {
+    if (LHS.ExtraProb != RHS.ExtraProb)
+      return LHS.ExtraProb > RHS.ExtraProb;
+    if (LHS.Bits != RHS.Bits)
+      return LHS.Bits > RHS.Bits;
+    return LHS.Mask < RHS.Mask;
+  });
+
+  // Create one test block per destination, in the sorted order.
+  BitTestInfo BTI;
+  for (const CaseBits &Bits : CBV) {
+    MachineBasicBlock *TestMBB =
+        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
+    BTI.push_back(BitTestCase(Bits.Mask, TestMBB, Bits.BB, Bits.ExtraProb));
+  }
+  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
+                            SI->getCondition(), -1U, MVT::Other, false,
+                            ContiguousRange, nullptr, nullptr, std::move(BTI),
+                            TotalProb);
+
+  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
+                                    BitTestCases.size() - 1, TotalProb);
+  return true;
+}
+
+void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
+#ifndef NDEBUG
+  for (const CaseCluster &CC : Clusters)
+    assert(CC.Low == CC.High && "Input clusters must be single-case");
+#endif
+
+  // Order the single-value clusters by signed case value.
+  llvm::sort(Clusters, [](const CaseCluster &LHS, const CaseCluster &RHS) {
+    return LHS.Low->getValue().slt(RHS.Low->getValue());
+  });
+
+  // Compact in place. The first Out entries are the finished clusters. An
+  // incoming case extends the last finished cluster when it has the same
+  // destination and its value is exactly one past that cluster's High.
+  unsigned Out = 0;
+  for (unsigned In = 0, E = Clusters.size(); In != E; ++In) {
+    const CaseCluster &Cur = Clusters[In];
+    if (Out != 0) {
+      CaseCluster &Prev = Clusters[Out - 1];
+      if (Prev.MBB == Cur.MBB &&
+          (Cur.Low->getValue() - Prev.High->getValue()) == 1) {
+        Prev.High = Cur.Low;
+        Prev.Prob += Cur.Prob;
+        continue;
+      }
+    }
+    Clusters[Out++] = Cur;
+  }
+  Clusters.resize(Out);
+}




More information about the llvm-commits mailing list