[llvm] r301649 - [InlineCost] Improve the cost heuristic for Switch

Jun Bum Lim via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 28 09:04:04 PDT 2017


Author: junbuml
Date: Fri Apr 28 11:04:03 2017
New Revision: 301649

URL: http://llvm.org/viewvc/llvm-project?rev=301649&view=rev
Log:
[InlineCost] Improve the cost heuristic for Switch

Summary:
The motivating example below has 13 cases but only 2 distinct targets:

```
lor.lhs.false2:                                   ; preds = %if.then
  switch i32 %Status, label %if.then27 [
    i32 -7012, label %if.end35
    i32 -10008, label %if.end35
    i32 -10016, label %if.end35
    i32 15000, label %if.end35
    i32 14013, label %if.end35
    i32 10114, label %if.end35
    i32 10107, label %if.end35
    i32 10105, label %if.end35
    i32 10013, label %if.end35
    i32 10011, label %if.end35
    i32 7008, label %if.end35
    i32 7007, label %if.end35
    i32 5002, label %if.end35
  ]
```
which is compiled into a balanced binary tree like this on AArch64 (similarly on X86):

```
.LBB853_9:                              // %lor.lhs.false2
        mov     w8, #10012
        cmp             w19, w8
        b.gt    .LBB853_14
// BB#10:                               // %lor.lhs.false2
        mov     w8, #5001
        cmp             w19, w8
        b.gt    .LBB853_18
// BB#11:                               // %lor.lhs.false2
        mov     w8, #-10016
        cmp             w19, w8
        b.eq    .LBB853_23
// BB#12:                               // %lor.lhs.false2
        mov     w8, #-10008
        cmp             w19, w8
        b.eq    .LBB853_23
// BB#13:                               // %lor.lhs.false2
        mov     w8, #-7012
        cmp             w19, w8
        b.eq    .LBB853_23
        b       .LBB853_3
.LBB853_14:                             // %lor.lhs.false2
        mov     w8, #14012
        cmp             w19, w8
        b.gt    .LBB853_21
// BB#15:                               // %lor.lhs.false2
        mov     w8, #-10105
        add             w8, w19, w8
        cmp             w8, #9          // =9
        b.hi    .LBB853_17
// BB#16:                               // %lor.lhs.false2
        orr     w9, wzr, #0x1
        lsl     w8, w9, w8
        mov     w9, #517
        and             w8, w8, w9
        cbnz    w8, .LBB853_23
.LBB853_17:                             // %lor.lhs.false2
        mov     w8, #10013
        cmp             w19, w8
        b.eq    .LBB853_23
        b       .LBB853_3
.LBB853_18:                             // %lor.lhs.false2
        mov     w8, #-7007
        add             w8, w19, w8
        cmp             w8, #2          // =2
        b.lo    .LBB853_23
// BB#19:                               // %lor.lhs.false2
        mov     w8, #5002
        cmp             w19, w8
        b.eq    .LBB853_23
// BB#20:                               // %lor.lhs.false2
        mov     w8, #10011
        cmp             w19, w8
        b.eq    .LBB853_23
        b       .LBB853_3
.LBB853_21:                             // %lor.lhs.false2
        mov     w8, #14013
        cmp             w19, w8
        b.eq    .LBB853_23
// BB#22:                               // %lor.lhs.false2
        mov     w8, #15000
        cmp             w19, w8
        b.ne    .LBB853_3
```
However, the inline cost model estimates the cost to be linear in the number
of distinct targets, so the cost of the above switch is just 2 InstrCosts. As a
result, the function containing this switch is inlined about 900 times.

This change uses the general switch-lowering strategy in the inline heuristic.
It estimates the number of case clusters using the same suitability checks for
a jump table or bit test that switch lowering uses. Considering the binary
search tree built for the clusters, this change makes the model linear in the
size of the balanced binary tree. The model is off by default for now:
  -inline-generic-switch-cost=false

This change was originally proposed by Haicheng in D29870.

Reviewers: hans, bmakam, chandlerc, eraman, haicheng, mcrosier

Reviewed By: hans

Subscribers: joerg, aemerson, llvm-commits, rengolin

Differential Revision: https://reviews.llvm.org/D31085

Added:
    llvm/trunk/test/Transforms/Inline/AArch64/switch.ll
Modified:
    llvm/trunk/include/llvm/Analysis/TargetTransformInfo.h
    llvm/trunk/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/trunk/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/trunk/include/llvm/Target/TargetLowering.h
    llvm/trunk/lib/Analysis/InlineCost.cpp
    llvm/trunk/lib/Analysis/TargetTransformInfo.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
    llvm/trunk/lib/CodeGen/TargetLoweringBase.cpp

Modified: llvm/trunk/include/llvm/Analysis/TargetTransformInfo.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/TargetTransformInfo.h?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/TargetTransformInfo.h (original)
+++ llvm/trunk/include/llvm/Analysis/TargetTransformInfo.h Fri Apr 28 11:04:03 2017
@@ -197,6 +197,12 @@ public:
   int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                        ArrayRef<const Value *> Arguments) const;
 
+  /// \return The estimated number of case clusters when lowering \p SI.
+  /// \p JTSize is set to the jump table size when \p SI is suitable for a jump
+  /// table, and to 0 otherwise.
+  unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                            unsigned &JTSize) const;
+
   /// \brief Estimate the cost of a given IR user when lowered.
   ///
   /// This can estimate the cost of either a ConstantExpr or Instruction when
@@ -764,6 +770,8 @@ public:
                                ArrayRef<Type *> ParamTys) = 0;
   virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
                                ArrayRef<const Value *> Arguments) = 0;
+  virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                                    unsigned &JTSize) = 0;
   virtual int getUserCost(const User *U) = 0;
   virtual bool hasBranchDivergence() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
@@ -1067,6 +1075,10 @@ public:
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+  unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                            unsigned &JTSize) override {
+    return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize);
+  }
   unsigned
   getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
                          OperandValueKind Opd2Info,

Modified: llvm/trunk/include/llvm/Analysis/TargetTransformInfoImpl.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/TargetTransformInfoImpl.h?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/TargetTransformInfoImpl.h (original)
+++ llvm/trunk/include/llvm/Analysis/TargetTransformInfoImpl.h Fri Apr 28 11:04:03 2017
@@ -114,6 +114,12 @@ public:
     return TTI::TCC_Free;
   }
 
+  unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                            unsigned &JTSize) {
+    JTSize = 0;
+    return SI.getNumCases();
+  }
+
   unsigned getCallCost(FunctionType *FTy, int NumArgs) {
     assert(FTy && "FunctionType must be provided to this routine.");
 

Modified: llvm/trunk/include/llvm/CodeGen/BasicTTIImpl.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/BasicTTIImpl.h?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/BasicTTIImpl.h (original)
+++ llvm/trunk/include/llvm/CodeGen/BasicTTIImpl.h Fri Apr 28 11:04:03 2017
@@ -171,6 +171,62 @@ public:
     return BaseT::getIntrinsicCost(IID, RetTy, ParamTys);
   }
 
+  unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                            unsigned &JumpTableSize) {
+    /// Try to find the estimated number of clusters. Note that the number of
+    /// clusters identified in this function could be different from the actual
+    /// numbers found in lowering. This function ignores switches that are
+    /// lowered with a mix of jump table / bit test / BTree. This function was
+    /// initially intended to be used when estimating the cost of switch in
+    /// inline cost heuristic, but it's a generic cost model to be used in other
+    /// places (e.g., in loop unrolling).
+    unsigned N = SI.getNumCases();
+    const TargetLoweringBase *TLI = getTLI();
+    const DataLayout &DL = this->getDataLayout();
+
+    JumpTableSize = 0;
+    bool IsJTAllowed = TLI->areJTsAllowed(SI.getParent()->getParent());
+
+    // Early exit if both a jump table and bit test are not allowed.
+    if (N < 1 || (!IsJTAllowed && DL.getPointerSizeInBits() < N))
+      return N;
+
+    APInt MaxCaseVal = SI.case_begin()->getCaseValue()->getValue();
+    APInt MinCaseVal = MaxCaseVal;
+    for (auto CI : SI.cases()) {
+      const APInt &CaseVal = CI.getCaseValue()->getValue();
+      if (CaseVal.sgt(MaxCaseVal))
+        MaxCaseVal = CaseVal;
+      if (CaseVal.slt(MinCaseVal))
+        MinCaseVal = CaseVal;
+    }
+
+    // Check if suitable for a bit test
+    if (N <= DL.getPointerSizeInBits()) {
+      SmallPtrSet<const BasicBlock *, 4> Dests;
+      for (auto I : SI.cases())
+        Dests.insert(I.getCaseSuccessor());
+
+      if (TLI->isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal,
+                                     DL))
+        return 1;
+    }
+
+    // Check if suitable for a jump table.
+    if (IsJTAllowed) {
+      if (N < 2 || N < TLI->getMinimumJumpTableEntries())
+        return N;
+      uint64_t Range =
+          (MaxCaseVal - MinCaseVal).getLimitedValue(UINT64_MAX - 1) + 1;
+      // Check whether a range of clusters is dense enough for a jump table
+      if (TLI->isSuitableForJumpTable(&SI, N, Range)) {
+        JumpTableSize = Range;
+        return 1;
+      }
+    }
+    return N;
+  }
+
   unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); }
 
   unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); }

Modified: llvm/trunk/include/llvm/Target/TargetLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Target/TargetLowering.h?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Target/TargetLowering.h (original)
+++ llvm/trunk/include/llvm/Target/TargetLowering.h Fri Apr 28 11:04:03 2017
@@ -775,6 +775,74 @@ public:
     return (!isTypeLegal(VT) && getOperationAction(Op, VT) == Custom);
   }
 
+  /// Return true if lowering to a jump table is allowed.
+  bool areJTsAllowed(const Function *Fn) const {
+    if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true")
+      return false;
+
+    return isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) ||
+           isOperationLegalOrCustom(ISD::BRIND, MVT::Other);
+  }
+
+  /// Check whether the range [Low,High] fits in a machine word.
+  bool rangeFitsInWord(const APInt &Low, const APInt &High,
+                       const DataLayout &DL) const {
+    // FIXME: Using the pointer type doesn't seem ideal.
+    uint64_t BW = DL.getPointerSizeInBits();
+    uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1;
+    return Range <= BW;
+  }
+
+  /// Return true if lowering to a jump table is suitable for a set of case
+  /// clusters which may contain \p NumCases cases, \p Range range of values.
+  /// FIXME: This function checks the maximum table size and density, but the
+  /// minimum size is not checked. It would be nice if the minimum size were
+  /// also checked within this function. Currently, the minimum size check is
+  /// performed in findJumpTables() in SelectionDAGBuilder and
+  /// getEstimatedNumberOfCaseClusters() in BasicTTIImpl.
+  bool isSuitableForJumpTable(const SwitchInst *SI, uint64_t NumCases,
+                              uint64_t Range) const {
+    const bool OptForSize = SI->getParent()->getParent()->optForSize();
+    const unsigned MinDensity = getMinimumJumpTableDensity(OptForSize);
+    const unsigned MaxJumpTableSize =
+        OptForSize || getMaximumJumpTableSize() == 0
+            ? UINT_MAX
+            : getMaximumJumpTableSize();
+    // Check whether a range of clusters is dense enough for a jump table.
+    if (Range <= MaxJumpTableSize &&
+        (NumCases * 100 >= Range * MinDensity)) {
+      return true;
+    }
+    return false;
+  }
+
+  /// Return true if lowering to a bit test is suitable for a set of case
+  /// clusters which contains \p NumDests unique destinations, \p Low and
+  /// \p High as its lowest and highest case values, and expects \p NumCmps
+  /// case value comparisons. Check if the number of destinations, comparison
+  /// metric, and range are all suitable.
+  bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps,
+                             const APInt &Low, const APInt &High,
+                             const DataLayout &DL) const {
+    // FIXME: I don't think NumCmps is the correct metric: a single case and a
+    // range of cases both require only one branch to lower. Just looking at the
+    // number of clusters and destinations should be enough to decide whether to
+    // build bit tests.
+
+    // To lower a range with bit tests, the range must fit the bitwidth of a
+    // machine word.
+    if (!rangeFitsInWord(Low, High, DL))
+      return false;
+
+    // Decide whether it's profitable to lower this range with bit tests. Each
+    // destination requires a bit test and branch, and there is an overall range
+    // check branch. For a small number of clusters, separate comparisons might
+    // be cheaper, and for many destinations, splitting the range might be
+    // better.
+    return (NumDests == 1 && NumCmps >= 3) || (NumDests == 2 && NumCmps >= 5) ||
+           (NumDests == 3 && NumCmps >= 6);
+  }
+
   /// Return true if the specified operation is illegal on this target or
   /// unlikely to be made legal with custom lowering. This is used to help guide
   /// high-level lowering decisions.
@@ -1149,6 +1217,9 @@ public:
   /// Return lower limit for number of blocks in a jump table.
   unsigned getMinimumJumpTableEntries() const;
 
+  /// Return lower limit of the density in a jump table.
+  unsigned getMinimumJumpTableDensity(bool OptForSize) const;
+
   /// Return upper limit for number of entries in a jump table.
   /// Zero if no limit.
   unsigned getMaximumJumpTableSize() const;

Modified: llvm/trunk/lib/Analysis/InlineCost.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/InlineCost.cpp?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/InlineCost.cpp (original)
+++ llvm/trunk/lib/Analysis/InlineCost.cpp Fri Apr 28 11:04:03 2017
@@ -54,6 +54,11 @@ static cl::opt<int>
                           cl::init(45),
                           cl::desc("Threshold for inlining cold callsites"));
 
+static cl::opt<bool>
+    EnableGenericSwitchCost("inline-generic-switch-cost", cl::Hidden,
+                            cl::init(false),
+                            cl::desc("Enable generic switch cost model"));
+
 // We introduce this threshold to help performance of instrumentation based
 // PGO before we actually hook up inliner with analysis passes such as BPI and
 // BFI.
@@ -998,11 +1003,72 @@ bool CallAnalyzer::visitSwitchInst(Switc
     if (isa<ConstantInt>(V))
       return true;
 
-  // Otherwise, we need to accumulate a cost proportional to the number of
-  // distinct successor blocks. This fan-out in the CFG cannot be represented
-  // for free even if we can represent the core switch as a jumptable that
-  // takes a single instruction.
-  //
+  if (EnableGenericSwitchCost) {
+    // Assume the most general case where the switch is lowered into
+    // either a jump table, bit test, or a balanced binary tree consisting of
+    // case clusters without merging adjacent clusters with the same
+    // destination. We do not consider the switches that are lowered with a mix
+    // of jump table/bit test/binary search tree. The cost of the switch is
+    // proportional to the size of the tree or the size of jump table range.
+
+    // Exit early for a large switch, assuming one case needs at least one
+    // instruction.
+    // FIXME: This is not true for a bit test, but ignore such case for now to
+    // save compile-time.
+    int64_t CostLowerBound =
+        std::min((int64_t)INT_MAX,
+                 (int64_t)SI.getNumCases() * InlineConstants::InstrCost + Cost);
+
+    if (CostLowerBound > Threshold) {
+      Cost = CostLowerBound;
+      return false;
+    }
+
+    unsigned JumpTableSize = 0;
+    unsigned NumCaseCluster =
+        TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize);
+
+    // If suitable for a jump table, consider the cost for the table size and
+    // branch to destination.
+    if (JumpTableSize) {
+      int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost +
+                       4 * InlineConstants::InstrCost;
+      Cost = std::min((int64_t)INT_MAX, JTCost + Cost);
+      return false;
+    }
+
+    // Considering forming a binary search, we should find the number of nodes
+    // which is the same as the number of comparisons when lowered. For a given
+    // number of clusters, n, we can define a recursive function, f(n), to find
+    // the number of nodes in the tree. The recursion is :
+    // f(n) = 1 + f(n/2) + f (n - n/2), when n > 3,
+    // and f(n) = n, when n <= 3.
+    // This leads to a binary tree where the leaf should be either f(2) or f(3)
+    // when n > 3.  So, the number of comparisons from leaves should be n, while
+    // the number of non-leaf nodes should be :
+    //   2^(log2(n) - 1) - 1
+    //   = 2^log2(n) * 2^-1 - 1
+    //   = n / 2 - 1.
+    // Considering comparisons from leaf and non-leaf nodes, we can estimate the
+    // number of comparisons in a simple closed form :
+    //   n + n / 2 - 1 = n * 3 / 2 - 1
+    if (NumCaseCluster <= 3) {
+      // Suppose a comparison includes one compare and one conditional branch.
+      Cost += NumCaseCluster * 2 * InlineConstants::InstrCost;
+      return false;
+    }
+    int64_t ExpectedNumberOfCompare = 3 * (uint64_t)NumCaseCluster / 2 - 1;
+    uint64_t SwitchCost =
+        ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost;
+    Cost = std::min((uint64_t)INT_MAX, SwitchCost + Cost);
+    return false;
+  }
+
+  // Use a simple switch cost model where we accumulate a cost proportional to
+  // the number of distinct successor blocks. This fan-out in the CFG cannot
+  // be represented for free even if we can represent the core switch as a
+  // jumptable that takes a single instruction.
+  ///
   // NB: We convert large switches which are just used to initialize large phi
   // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent
   // inlining those. It will prevent inlining in cases where the optimization

Modified: llvm/trunk/lib/Analysis/TargetTransformInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/TargetTransformInfo.cpp?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/TargetTransformInfo.cpp (original)
+++ llvm/trunk/lib/Analysis/TargetTransformInfo.cpp Fri Apr 28 11:04:03 2017
@@ -83,6 +83,12 @@ int TargetTransformInfo::getIntrinsicCos
   return Cost;
 }
 
+unsigned
+TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
+                                                      unsigned &JTSize) const {
+  return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
+}
+
 int TargetTransformInfo::getUserCost(const User *U) const {
   int Cost = TTIImpl->getUserCost(U);
   assert(Cost >= 0 && "TTI should not produce negative costs!");

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp Fri Apr 28 11:04:03 2017
@@ -83,20 +83,6 @@ LimitFPPrecision("limit-float-precision"
                           "for some float libcalls"),
                  cl::location(LimitFloatPrecision),
                  cl::init(0));
-
-/// Minimum jump table density for normal functions.
-static cl::opt<unsigned>
-JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden,
-                 cl::desc("Minimum density for building a jump table in "
-                          "a normal function"));
-
-/// Minimum jump table density for -Os or -Oz functions.
-static cl::opt<unsigned>
-OptsizeJumpTableDensity("optsize-jump-table-density", cl::init(40), cl::Hidden,
-                        cl::desc("Minimum density for building a jump table in "
-                                 "an optsize function"));
-
-
 // Limit the width of DAG chains. This is important in general to prevent
 // DAG-based analysis from blowing up. For example, alias analysis and
 // load clustering may not complete in reasonable time. It is difficult to
@@ -8589,13 +8575,10 @@ void SelectionDAGBuilder::updateDAGForMa
     HasTailCall = true;
 }
 
-bool SelectionDAGBuilder::isDense(const CaseClusterVector &Clusters,
-                                  const SmallVectorImpl<unsigned> &TotalCases,
-                                  unsigned First, unsigned Last,
-                                  unsigned Density) const {
+uint64_t
+SelectionDAGBuilder::getJumpTableRange(const CaseClusterVector &Clusters,
+                                       unsigned First, unsigned Last) const {
   assert(Last >= First);
-  assert(TotalCases[Last] >= TotalCases[First]);
-
   const APInt &LowCase = Clusters[First].Low->getValue();
   const APInt &HighCase = Clusters[Last].High->getValue();
   assert(LowCase.getBitWidth() == HighCase.getBitWidth());
@@ -8604,26 +8587,17 @@ bool SelectionDAGBuilder::isDense(const
   // comparison to lower. We should discriminate against such consecutive ranges
   // in jump tables.
 
-  uint64_t Diff = (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100);
-  uint64_t Range = Diff + 1;
+  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
+}
 
+uint64_t SelectionDAGBuilder::getJumpTableNumCases(
+    const SmallVectorImpl<unsigned> &TotalCases, unsigned First,
+    unsigned Last) const {
+  assert(Last >= First);
+  assert(TotalCases[Last] >= TotalCases[First]);
   uint64_t NumCases =
       TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
-
-  assert(NumCases < UINT64_MAX / 100);
-  assert(Range >= NumCases);
-
-  return NumCases * 100 >= Range * Density;
-}
-
-static inline bool areJTsAllowed(const TargetLowering &TLI,
-                                 const SwitchInst *SI) {
-  const Function *Fn = SI->getParent()->getParent();
-  if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true")
-    return false;
-
-  return TLI.isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) ||
-         TLI.isOperationLegalOrCustom(ISD::BRIND, MVT::Other);
+  return NumCases;
 }
 
 bool SelectionDAGBuilder::buildJumpTable(const CaseClusterVector &Clusters,
@@ -8662,10 +8636,11 @@ bool SelectionDAGBuilder::buildJumpTable
     JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
   }
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned NumDests = JTProbs.size();
-  if (isSuitableForBitTests(NumDests, NumCmps,
-                            Clusters[First].Low->getValue(),
-                            Clusters[Last].High->getValue())) {
+  if (TLI.isSuitableForBitTests(
+          NumDests, NumCmps, Clusters[First].Low->getValue(),
+          Clusters[Last].High->getValue(), DAG.getDataLayout())) {
     // Clusters[First..Last] should be lowered as bit tests instead.
     return false;
   }
@@ -8686,7 +8661,6 @@ bool SelectionDAGBuilder::buildJumpTable
   }
   JumpTableMBB->normalizeSuccProbs();
 
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding())
                      ->createJumpTableIndex(Table);
 
@@ -8715,17 +8689,12 @@ void SelectionDAGBuilder::findJumpTables
 #endif
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  if (!areJTsAllowed(TLI, SI))
+  if (!TLI.areJTsAllowed(SI->getParent()->getParent()))
     return;
 
-  const bool OptForSize = DefaultMBB->getParent()->getFunction()->optForSize();
-
   const int64_t N = Clusters.size();
   const unsigned MinJumpTableEntries = TLI.getMinimumJumpTableEntries();
   const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
-  const unsigned MaxJumpTableSize =
-                   OptForSize || TLI.getMaximumJumpTableSize() == 0
-                   ? UINT_MAX : TLI.getMaximumJumpTableSize();
 
   if (N < 2 || N < MinJumpTableEntries)
     return;
@@ -8740,15 +8709,12 @@ void SelectionDAGBuilder::findJumpTables
       TotalCases[i] += TotalCases[i - 1];
   }
 
-  const unsigned MinDensity =
-    OptForSize ? OptsizeJumpTableDensity : JumpTableDensity;
-
   // Cheap case: the whole range may be suitable for jump table.
-  unsigned JumpTableSize = (Clusters[N - 1].High->getValue() -
-                            Clusters[0].Low->getValue())
-                           .getLimitedValue(UINT_MAX - 1) + 1;
-  if (JumpTableSize <= MaxJumpTableSize &&
-      isDense(Clusters, TotalCases, 0, N - 1, MinDensity)) {
+  uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
+  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
+  assert(NumCases < UINT64_MAX / 100);
+  assert(Range >= NumCases);
+  if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
     CaseCluster JTCluster;
     if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
       Clusters[0] = JTCluster;
@@ -8801,11 +8767,11 @@ void SelectionDAGBuilder::findJumpTables
     // Search for a solution that results in fewer partitions.
     for (int64_t j = N - 1; j > i; j--) {
       // Try building a partition from Clusters[i..j].
-      JumpTableSize = (Clusters[j].High->getValue() -
-                       Clusters[i].Low->getValue())
-                      .getLimitedValue(UINT_MAX - 1) + 1;
-      if (JumpTableSize <= MaxJumpTableSize &&
-          isDense(Clusters, TotalCases, i, j, MinDensity)) {
+      uint64_t Range = getJumpTableRange(Clusters, i, j);
+      uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
+      assert(NumCases < UINT64_MAX / 100);
+      assert(Range >= NumCases);
+      if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
         unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
         unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
         int64_t NumEntries = j - i + 1;
@@ -8849,36 +8815,6 @@ void SelectionDAGBuilder::findJumpTables
   Clusters.resize(DstIndex);
 }
 
-bool SelectionDAGBuilder::rangeFitsInWord(const APInt &Low, const APInt &High) {
-  // FIXME: Using the pointer type doesn't seem ideal.
-  uint64_t BW = DAG.getDataLayout().getPointerSizeInBits();
-  uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1;
-  return Range <= BW;
-}
-
-bool SelectionDAGBuilder::isSuitableForBitTests(unsigned NumDests,
-                                                unsigned NumCmps,
-                                                const APInt &Low,
-                                                const APInt &High) {
-  // FIXME: I don't think NumCmps is the correct metric: a single case and a
-  // range of cases both require only one branch to lower. Just looking at the
-  // number of clusters and destinations should be enough to decide whether to
-  // build bit tests.
-
-  // To lower a range with bit tests, the range must fit the bitwidth of a
-  // machine word.
-  if (!rangeFitsInWord(Low, High))
-    return false;
-
-  // Decide whether it's profitable to lower this range with bit tests. Each
-  // destination requires a bit test and branch, and there is an overall range
-  // check branch. For a small number of clusters, separate comparisons might be
-  // cheaper, and for many destinations, splitting the range might be better.
-  return (NumDests == 1 && NumCmps >= 3) ||
-         (NumDests == 2 && NumCmps >= 5) ||
-         (NumDests == 3 && NumCmps >= 6);
-}
-
 bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
                                         unsigned First, unsigned Last,
                                         const SwitchInst *SI,
@@ -8900,16 +8836,17 @@ bool SelectionDAGBuilder::buildBitTests(
   APInt High = Clusters[Last].High->getValue();
   assert(Low.slt(High));
 
-  if (!isSuitableForBitTests(NumDests, NumCmps, Low, High))
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  const DataLayout &DL = DAG.getDataLayout();
+  if (!TLI.isSuitableForBitTests(NumDests, NumCmps, Low, High, DL))
     return false;
 
   APInt LowBound;
   APInt CmpRange;
 
-  const int BitWidth = DAG.getTargetLoweringInfo()
-                           .getPointerTy(DAG.getDataLayout())
-                           .getSizeInBits();
-  assert(rangeFitsInWord(Low, High) && "Case range must fit in bit mask!");
+  const int BitWidth = TLI.getPointerTy(DL).getSizeInBits();
+  assert(TLI.rangeFitsInWord(Low, High, DL) &&
+         "Case range must fit in bit mask!");
 
   // Check if the clusters cover a contiguous range such that no value in the
   // range will jump to the default statement.
@@ -8999,7 +8936,9 @@ void SelectionDAGBuilder::findBitTestClu
 
   // If target does not have legal shift left, do not emit bit tests at all.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  EVT PTy = TLI.getPointerTy(DAG.getDataLayout());
+  const DataLayout &DL = DAG.getDataLayout();
+
+  EVT PTy = TLI.getPointerTy(DL);
   if (!TLI.isOperationLegal(ISD::SHL, PTy))
     return;
 
@@ -9030,8 +8969,8 @@ void SelectionDAGBuilder::findBitTestClu
       // Try building a partition from Clusters[i..j].
 
       // Check the range.
-      if (!rangeFitsInWord(Clusters[i].Low->getValue(),
-                           Clusters[j].High->getValue()))
+      if (!TLI.rangeFitsInWord(Clusters[i].Low->getValue(),
+                               Clusters[j].High->getValue(), DL))
         continue;
 
       // Check nbr of destinations and cluster types.

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h Fri Apr 28 11:04:03 2017
@@ -304,10 +304,13 @@ private:
     BranchProbability DefaultProb;
   };
 
-  /// Check whether a range of clusters is dense enough for a jump table.
-  bool isDense(const CaseClusterVector &Clusters,
-               const SmallVectorImpl<unsigned> &TotalCases,
-               unsigned First, unsigned Last, unsigned MinDensity) const;
+  /// Return the range of value in [First..Last].
+  uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
+                             unsigned Last) const;
+
+  /// Return the number of cases in [First..Last].
+  uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+                                unsigned First, unsigned Last) const;
 
   /// Build a jump table cluster from Clusters[First..Last]. Returns false if it
   /// decides it's not a good idea.
@@ -319,14 +322,6 @@ private:
   void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
                       MachineBasicBlock *DefaultMBB);
 
-  /// Check whether the range [Low,High] fits in a machine word.
-  bool rangeFitsInWord(const APInt &Low, const APInt &High);
-
-  /// Check whether these clusters are suitable for lowering with bit tests based
-  /// on the number of destinations, comparison metric, and range.
-  bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps,
-                             const APInt &Low, const APInt &High);
-
   /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
   /// decides it's not a good idea.
   bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,

Modified: llvm/trunk/lib/CodeGen/TargetLoweringBase.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/TargetLoweringBase.cpp?rev=301649&r1=301648&r2=301649&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/TargetLoweringBase.cpp (original)
+++ llvm/trunk/lib/CodeGen/TargetLoweringBase.cpp Fri Apr 28 11:04:03 2017
@@ -53,6 +53,18 @@ static cl::opt<unsigned> MaximumJumpTabl
   ("max-jump-table-size", cl::init(0), cl::Hidden,
    cl::desc("Set maximum size of jump tables; zero for no limit."));
 
+/// Minimum jump table density for normal functions.
+static cl::opt<unsigned>
+    JumpTableDensity("jump-table-density", cl::init(10), cl::Hidden,
+                     cl::desc("Minimum density for building a jump table in "
+                              "a normal function"));
+
+/// Minimum jump table density for -Os or -Oz functions.
+static cl::opt<unsigned> OptsizeJumpTableDensity(
+    "optsize-jump-table-density", cl::init(40), cl::Hidden,
+    cl::desc("Minimum density for building a jump table in "
+             "an optsize function"));
+
 // Although this default value is arbitrary, it is not random. It is assumed
 // that a condition that evaluates the same way by a higher percentage than this
 // is best represented as control flow. Therefore, the default value N should be
@@ -1901,6 +1913,10 @@ void TargetLoweringBase::setMinimumJumpT
   MinimumJumpTableEntries = Val;
 }
 
+unsigned TargetLoweringBase::getMinimumJumpTableDensity(bool OptForSize) const {
+  return OptForSize ? OptsizeJumpTableDensity : JumpTableDensity;
+}
+
 unsigned TargetLoweringBase::getMaximumJumpTableSize() const {
   return MaximumJumpTableSize;
 }

Added: llvm/trunk/test/Transforms/Inline/AArch64/switch.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/Inline/AArch64/switch.ll?rev=301649&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/Inline/AArch64/switch.ll (added)
+++ llvm/trunk/test/Transforms/Inline/AArch64/switch.ll Fri Apr 28 11:04:03 2017
@@ -0,0 +1,123 @@
+; RUN: opt < %s -inline -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s
+; RUN: opt < %s -passes='cgscc(inline)' -inline-threshold=20 -S -mtriple=aarch64-none-linux -inline-generic-switch-cost=true | FileCheck %s
+
+define i32 @callee_range(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+    i32 0, label %sw.bb0
+    i32 1000, label %sw.bb1
+    i32 2000, label %sw.bb1
+    i32 3000, label %sw.bb1
+    i32 4000, label %sw.bb1
+    i32 5000, label %sw.bb1
+    i32 6000, label %sw.bb1
+    i32 7000, label %sw.bb1
+    i32 8000, label %sw.bb1
+    i32 9000, label %sw.bb1
+  ]
+
+sw.default:
+  store volatile i32 %a, i32* %P
+  br label %return
+sw.bb0:
+  store volatile i32 %a, i32* %P
+  br label %return
+sw.bb1:
+  store volatile i32 %a, i32* %P
+  br label %return
+return:
+  ret i32 42
+}
+
+define i32 @caller_range(i32 %a, i32* %P) {
+; CHECK-LABEL: @caller_range(
+; CHECK: call i32 @callee_range
+  %r = call i32 @callee_range(i32 %a, i32* %P)
+  ret i32 %r
+}
+
+define i32 @callee_bittest(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+    i32 0, label %sw.bb0
+    i32 1, label %sw.bb1
+    i32 2, label %sw.bb2
+    i32 3, label %sw.bb0
+    i32 4, label %sw.bb1
+    i32 5, label %sw.bb2
+    i32 6, label %sw.bb0
+    i32 7, label %sw.bb1
+    i32 8, label %sw.bb2
+  ]
+
+sw.default:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb0:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb1:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb2:
+  br label %return
+
+return:
+  ret i32 42
+}
+
+
+define i32 @caller_bittest(i32 %a, i32* %P) {
+; CHECK-LABEL: @caller_bittest(
+; CHECK-NOT: call i32 @callee_bittest
+  %r= call i32 @callee_bittest(i32 %a, i32* %P)
+  ret i32 %r
+}
+
+define i32 @callee_jumptable(i32 %a, i32* %P) {
+  switch i32 %a, label %sw.default [
+    i32 1001, label %sw.bb101
+    i32 1002, label %sw.bb102
+    i32 1003, label %sw.bb103
+    i32 1004, label %sw.bb104
+    i32 1005, label %sw.bb101
+    i32 1006, label %sw.bb102
+    i32 1007, label %sw.bb103
+    i32 1008, label %sw.bb104
+    i32 1009, label %sw.bb101
+    i32 1010, label %sw.bb102
+    i32 1011, label %sw.bb103
+    i32 1012, label %sw.bb104
+ ]
+
+sw.default:
+  br label %return
+
+sw.bb101:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb102:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb103:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+sw.bb104:
+  store volatile i32 %a, i32* %P
+  br label %return
+
+return:
+  ret i32 42
+}
+
+define i32 @caller_jumptable(i32 %a, i32 %b, i32* %P) {
+; CHECK-LABEL: @caller_jumptable(
+; CHECK: call i32 @callee_jumptable
+  %r = call i32 @callee_jumptable(i32 %b, i32* %P)
+  ret i32 %r
+}
+




More information about the llvm-commits mailing list