[llvm] 1db51d8 - [Transform] Rewrite LowerSwitch using APInt

Peter Rong via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 5 14:30:48 PST 2023


Author: Peter Rong
Date: 2023-01-05T14:30:42-08:00
New Revision: 1db51d8eb2d220a4f0000555ada310990098cf5b

URL: https://github.com/llvm/llvm-project/commit/1db51d8eb2d220a4f0000555ada310990098cf5b
DIFF: https://github.com/llvm/llvm-project/commit/1db51d8eb2d220a4f0000555ada310990098cf5b.diff

LOG: [Transform] Rewrite LowerSwitch using APInt

This rewrite fixes https://github.com/llvm/llvm-project/issues/59316.

Previously LowerSwitch used int64_t, which crashed on switch cases using integers wider than 64 bits.
Using APInt fixes this problem. This patch also includes a regression test.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D140747

Added: 
    llvm/test/Transforms/LowerSwitch/pr59316.ll

Modified: 
    llvm/lib/Transforms/Utils/LowerSwitch.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
index 9e3095fa291f8..26aebdfff6408 100644
--- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -52,7 +52,7 @@ using namespace llvm;
 namespace {
 
 struct IntRange {
-  int64_t Low, High;
+  APInt Low, High;
 };
 
 } // end anonymous namespace
@@ -66,8 +66,8 @@ bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
   // then check if the Low field is <= R.Low. If so, we
   // have a Range that covers R.
   auto I = llvm::lower_bound(
-      Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; });
-  return I != Ranges.end() && I->Low <= R.Low;
+      Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
+  return I != Ranges.end() && I->Low.sle(R.Low);
 }
 
 struct CaseRange {
@@ -116,15 +116,14 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) {
 /// 2) Removed if subsequent incoming values now share the same case, i.e.,
 /// multiple outcome edges are condensed into one. This is necessary to keep the
 /// number of phi values equal to the number of branches to SuccBB.
-void FixPhis(
-    BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
-    const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) {
+void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
+             const APInt &NumMergedCases) {
   for (auto &I : SuccBB->phis()) {
     PHINode *PN = cast<PHINode>(&I);
 
     // Only update the first occurrence if NewBB exists.
     unsigned Idx = 0, E = PN->getNumIncomingValues();
-    unsigned LocalNumMergedCases = NumMergedCases;
+    APInt LocalNumMergedCases = NumMergedCases;
     for (; Idx != E && NewBB; ++Idx) {
       if (PN->getIncomingBlock(Idx) == OrigBB) {
         PN->setIncomingBlock(Idx, NewBB);
@@ -139,10 +138,10 @@ void FixPhis(
     // Remove additional occurrences coming from condensed cases and keep the
     // number of incoming values equal to the number of branches to SuccBB.
     SmallVector<unsigned, 8> Indices;
-    for (; LocalNumMergedCases > 0 && Idx < E; ++Idx)
+    for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx)
       if (PN->getIncomingBlock(Idx) == OrigBB) {
         Indices.push_back(Idx);
-        LocalNumMergedCases--;
+        LocalNumMergedCases -= 1;
       }
     // Remove incoming values in the reverse order to prevent invalidating
     // *successive* index.
@@ -209,8 +208,8 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
     PHINode *PN = cast<PHINode>(I);
     // Remove all but one incoming entries from the cluster
-    uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue();
-    for (uint64_t j = 0; j < Range; ++j) {
+    APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
+    for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) {
       PN->removeIncomingValue(OrigBlock);
     }
 
@@ -241,8 +240,7 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
     // emitting the code that checks if the value actually falls in the range
     // because the bounds already tell us so.
     if (Begin->Low == LowerBound && Begin->High == UpperBound) {
-      unsigned NumMergedCases = 0;
-      NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue();
+      APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue();
       FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
       return Begin->BB;
     }
@@ -273,17 +271,17 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
 
   if (!UnreachableRanges.empty()) {
     // Check if the gap between LHS's highest and NewLowerBound is unreachable.
-    int64_t GapLow = LHS.back().High->getSExtValue() + 1;
-    int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
+    APInt GapLow = LHS.back().High->getValue() + 1;
+    APInt GapHigh = NewLowerBound->getValue() - 1;
     IntRange Gap = {GapLow, GapHigh};
-    if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
+    if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
       NewUpperBound = LHS.back().High;
   }
 
-  LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", "
-                    << NewUpperBound->getSExtValue() << "]\n"
-                    << "RHS Bounds ==> [" << NewLowerBound->getSExtValue()
-                    << ", " << UpperBound->getSExtValue() << "]\n");
+  LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", "
+                    << NewUpperBound->getValue() << "]\n"
+                    << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", "
+                    << UpperBound->getValue() << "]\n");
 
   // Create a new node that checks if the value is < pivot. Go to the
   // left branch if it is and right branch if not.
@@ -327,14 +325,15 @@ unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) {
   if (Cases.size() >= 2) {
     CaseItr I = Cases.begin();
     for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
-      int64_t nextValue = J->Low->getSExtValue();
-      int64_t currentValue = I->High->getSExtValue();
+      const APInt &nextValue = J->Low->getValue();
+      const APInt &currentValue = I->High->getValue();
       BasicBlock *nextBB = J->BB;
       BasicBlock *currentBB = I->BB;
 
       // If the two neighboring cases go to the same destination, merge them
       // into a single case.
-      assert(nextValue > currentValue && "Cases should be strictly ascending");
+      assert(nextValue.sgt(currentValue) &&
+             "Cases should be strictly ascending");
       if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
         I->High = J->High;
         // FIXME: Combine branch weights.
@@ -369,6 +368,10 @@ void ProcessSwitchInst(SwitchInst *SI,
   // Prepare cases vector.
   CaseVector Cases;
   const unsigned NumSimpleCases = Clusterify(Cases, SI);
+  IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
+  const unsigned BitWidth = IT->getBitWidth();
+  APInt SignedZero(BitWidth, 0);
+  APInt UnsignedMax = APInt::getMaxValue(BitWidth);
   LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
                     << ". Total non-default cases: " << NumSimpleCases
                     << "\nCase clusters: " << Cases << "\n");
@@ -377,7 +380,7 @@ void ProcessSwitchInst(SwitchInst *SI,
   if (Cases.empty()) {
     BranchInst::Create(Default, OrigBlock);
     // Remove all the references from Default's PHIs to OrigBlock, but one.
-    FixPhis(Default, OrigBlock, OrigBlock);
+    FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax);
     SI->eraseFromParent();
     return;
   }
@@ -414,8 +417,8 @@ void ProcessSwitchInst(SwitchInst *SI,
     // the unlikely event that some of them survived, we just conservatively
     // maintain the invariant that all the cases lie between the bounds. This
     // may, however, still render the default case effectively unreachable.
-    APInt Low = Cases.front().Low->getValue();
-    APInt High = Cases.back().High->getValue();
+    const APInt &Low = Cases.front().Low->getValue();
+    const APInt &High = Cases.back().High->getValue();
     APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low);
     APInt Max = APIntOps::smax(ValRange.getSignedMax(), High);
 
@@ -427,35 +430,38 @@ void ProcessSwitchInst(SwitchInst *SI,
   std::vector<IntRange> UnreachableRanges;
 
   if (DefaultIsUnreachableFromSwitch) {
-    DenseMap<BasicBlock *, unsigned> Popularity;
-    unsigned MaxPop = 0;
+    DenseMap<BasicBlock *, APInt> Popularity;
+    APInt MaxPop(SignedZero);
     BasicBlock *PopSucc = nullptr;
 
-    IntRange R = {std::numeric_limits<int64_t>::min(),
-                  std::numeric_limits<int64_t>::max()};
+    APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
+    APInt SignedMin = APInt::getSignedMinValue(BitWidth);
+    IntRange R = {SignedMin, SignedMax};
     UnreachableRanges.push_back(R);
     for (const auto &I : Cases) {
-      int64_t Low = I.Low->getSExtValue();
-      int64_t High = I.High->getSExtValue();
+      const APInt &Low = I.Low->getValue();
+      const APInt &High = I.High->getValue();
 
       IntRange &LastRange = UnreachableRanges.back();
-      if (LastRange.Low == Low) {
+      if (LastRange.Low.eq(Low)) {
         // There is nothing left of the previous range.
         UnreachableRanges.pop_back();
       } else {
         // Terminate the previous range.
-        assert(Low > LastRange.Low);
+        assert(Low.sgt(LastRange.Low));
         LastRange.High = Low - 1;
       }
-      if (High != std::numeric_limits<int64_t>::max()) {
-        IntRange R = {High + 1, std::numeric_limits<int64_t>::max()};
+      if (High.ne(SignedMax)) {
+        IntRange R = {High + 1, SignedMax};
         UnreachableRanges.push_back(R);
       }
 
       // Count popularity.
-      int64_t N = High - Low + 1;
-      unsigned &Pop = Popularity[I.BB];
-      if ((Pop += N) > MaxPop) {
+      APInt N = High - Low + 1;
+      assert(N.sge(SignedZero) && "Popularity shouldn't be negative.");
+      // Explict insert to make sure the bitwidth of APInts match
+      APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second;
+      if ((Pop += N).sgt(MaxPop)) {
         MaxPop = Pop;
         PopSucc = I.BB;
       }
@@ -464,10 +470,10 @@ void ProcessSwitchInst(SwitchInst *SI,
     /* UnreachableRanges should be sorted and the ranges non-adjacent. */
     for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
          I != E; ++I) {
-      assert(I->Low <= I->High);
+      assert(I->Low.sle(I->High));
       auto Next = I + 1;
       if (Next != E) {
-        assert(Next->Low > I->High);
+        assert(Next->Low.sgt(I->High));
       }
     }
 #endif
@@ -480,7 +486,8 @@ void ProcessSwitchInst(SwitchInst *SI,
 
     // Use the most popular block as the new default, reducing the number of
     // cases.
-    assert(MaxPop > 0 && PopSucc);
+    assert(MaxPop.sgt(SignedZero) && PopSucc &&
+           "Max populartion shouldn't be negative.");
     Default = PopSucc;
     llvm::erase_if(Cases,
                    [PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
@@ -491,7 +498,7 @@ void ProcessSwitchInst(SwitchInst *SI,
       SI->eraseFromParent();
       // As all the cases have been replaced with a single branch, only keep
       // one entry in the PHI nodes.
-      for (unsigned I = 0; I < (MaxPop - 1); ++I)
+      for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I)
         PopSucc->removePredecessor(OrigBlock);
       return;
     }
@@ -512,7 +519,7 @@ void ProcessSwitchInst(SwitchInst *SI,
   // that SwitchBlock is the same as Default, under which the PHIs in Default
   // are fixed inside SwitchConvert().
   if (SwitchBlock != Default)
-    FixPhis(Default, OrigBlock, nullptr);
+    FixPhis(Default, OrigBlock, nullptr, UnsignedMax);
 
   // Branch to our shiny new if-then stuff...
   BranchInst::Create(SwitchBlock, OrigBlock);

diff  --git a/llvm/test/Transforms/LowerSwitch/pr59316.ll b/llvm/test/Transforms/LowerSwitch/pr59316.ll
new file mode 100644
index 0000000000000..2e4226c71ea7d
--- /dev/null
+++ b/llvm/test/Transforms/LowerSwitch/pr59316.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=lowerswitch -S | FileCheck %s
+
+define i64 @f(i1 %bool, i128 %i128) {
+; CHECK-LABEL: @f(
+; CHECK-NEXT:  BB:
+; CHECK-NEXT:    br label [[NODEBLOCK1:%.*]]
+; CHECK:       NodeBlock1:
+; CHECK-NEXT:    [[PIVOT2:%.*]] = icmp slt i128 [[I128:%.*]], 16201310291018008446
+; CHECK-NEXT:    br i1 [[PIVOT2]], label [[LEAFBLOCK:%.*]], label [[NODEBLOCK:%.*]]
+; CHECK:       NodeBlock:
+; CHECK-NEXT:    [[PIVOT:%.*]] = icmp slt i128 [[I128]], 16201310291018008447
+; CHECK-NEXT:    br i1 [[PIVOT]], label [[SW_C3:%.*]], label [[SW_C2:%.*]]
+; CHECK:       LeafBlock:
+; CHECK-NEXT:    [[SWITCHLEAF:%.*]] = icmp eq i128 [[I128]], 16201310291018008445
+; CHECK-NEXT:    br i1 [[SWITCHLEAF]], label [[SW_C4:%.*]], label [[SW_C1:%.*]]
+; CHECK:       BB1:
+; CHECK-NEXT:    unreachable
+; CHECK:       SW_C1:
+; CHECK-NEXT:    br i1 [[BOOL:%.*]], label [[BB1:%.*]], label [[SW_C1]]
+; CHECK:       SW_C2:
+; CHECK-NEXT:    ret i64 0
+; CHECK:       SW_C3:
+; CHECK-NEXT:    ret i64 1
+; CHECK:       SW_C4:
+; CHECK-NEXT:    ret i64 2
+;
+BB:
+  switch i128 %i128, label %BB1 [
+  i128 627, label %SW_C1
+  i128 16201310291018008447, label %SW_C2
+  i128 16201310291018008446, label %SW_C3
+  i128 16201310291018008445, label %SW_C4
+  ]
+
+BB1:                                              ; preds = %SW_C1, %BB
+  unreachable
+
+SW_C1:                                            ; preds = %SW_C1, %BB
+  br i1 %bool, label %BB1, label %SW_C1
+
+SW_C2:                                            ; preds = %BB
+  ret i64 0
+
+SW_C3:                                            ; preds = %BB
+  ret i64 1
+
+SW_C4:                                            ; preds = %BB
+  ret i64 2
+}
+
+define i64 @f_empty(i1 %bool, i128 %i128) {
+; CHECK-LABEL: @f_empty(
+; CHECK-NEXT:  BB:
+; CHECK-NEXT:    br label [[BB1:%.*]]
+; CHECK:       BB1:
+; CHECK-NEXT:    unreachable
+;
+BB:
+  switch i128 %i128, label %BB1 []
+
+BB1:                                              ; preds = %BB
+  unreachable
+}


        


More information about the llvm-commits mailing list