[llvm] 9b70a28 - [Transform] Rewrite LowerSwitch using APInt

Peter Rong via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 24 20:22:12 PST 2023


Author: Peter Rong
Date: 2023-01-24T20:22:06-08:00
New Revision: 9b70a28e0d767f99bdc778356e81b4d072f59819

URL: https://github.com/llvm/llvm-project/commit/9b70a28e0d767f99bdc778356e81b4d072f59819
DIFF: https://github.com/llvm/llvm-project/commit/9b70a28e0d767f99bdc778356e81b4d072f59819.diff

LOG: [Transform] Rewrite LowerSwitch using APInt

This rewrite fixes https://github.com/llvm/llvm-project/issues/59316.

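Previously LowerSwitch used int64_t, which crashed on switch cases whose integers are wider than 64 bits.
Using APInt fixes this problem. Case popularity is also counted with one extra bit of precision, so narrow conditions such as i1 and i2, whose full value count wraps around at their own width, are handled correctly. This patch also includes regression tests.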

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D140747

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/LowerSwitch.cpp
    llvm/test/Transforms/LowerSwitch/pr59316.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
index 26aebdfff6408..227de425ff855 100644
--- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -370,7 +370,9 @@ void ProcessSwitchInst(SwitchInst *SI,
   const unsigned NumSimpleCases = Clusterify(Cases, SI);
   IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
   const unsigned BitWidth = IT->getBitWidth();
-  APInt SignedZero(BitWidth, 0);
+  // Explicitly use one extra bit of precision so that counting every value of
+  // the type cannot wrap around, e.g. `UnsignedMax - 0 + 1 == 0`.
+  APInt UnsignedZero(BitWidth + 1, 0);
   APInt UnsignedMax = APInt::getMaxValue(BitWidth);
   LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
                     << ". Total non-default cases: " << NumSimpleCases
@@ -431,7 +433,7 @@ void ProcessSwitchInst(SwitchInst *SI,
 
   if (DefaultIsUnreachableFromSwitch) {
     DenseMap<BasicBlock *, APInt> Popularity;
-    APInt MaxPop(SignedZero);
+    APInt MaxPop(UnsignedZero);
     BasicBlock *PopSucc = nullptr;
 
     APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
@@ -457,11 +459,11 @@ void ProcessSwitchInst(SwitchInst *SI,
       }
 
       // Count popularity.
-      APInt N = High - Low + 1;
-      assert(N.sge(SignedZero) && "Popularity shouldn't be negative.");
+      assert(High.sge(Low) && "Popularity shouldn't be negative.");
+      APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1;
       // Explict insert to make sure the bitwidth of APInts match
-      APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second;
-      if ((Pop += N).sgt(MaxPop)) {
+      APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second;
+      if ((Pop += N).ugt(MaxPop)) {
         MaxPop = Pop;
         PopSucc = I.BB;
       }
@@ -486,8 +488,6 @@ void ProcessSwitchInst(SwitchInst *SI,
 
     // Use the most popular block as the new default, reducing the number of
     // cases.
-    assert(MaxPop.sgt(SignedZero) && PopSucc &&
-           "Max populartion shouldn't be negative.");
     Default = PopSucc;
     llvm::erase_if(Cases,
                    [PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
@@ -498,8 +498,9 @@ void ProcessSwitchInst(SwitchInst *SI,
       SI->eraseFromParent();
       // As all the cases have been replaced with a single branch, only keep
       // one entry in the PHI nodes.
-      for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I)
-        PopSucc->removePredecessor(OrigBlock);
+      if (!MaxPop.isZero())
+        for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I)
+          PopSucc->removePredecessor(OrigBlock);
       return;
     }
 

diff  --git a/llvm/test/Transforms/LowerSwitch/pr59316.ll b/llvm/test/Transforms/LowerSwitch/pr59316.ll
index 2e4226c71ea7d..0616ace672965 100644
--- a/llvm/test/Transforms/LowerSwitch/pr59316.ll
+++ b/llvm/test/Transforms/LowerSwitch/pr59316.ll
@@ -62,3 +62,32 @@ BB:
 BB1:                                              ; preds = %BB
   unreachable
 }
+
+define void @f_i1() { ; Regression test: i1 switch with a single case.
+entry:
+  switch i1 false, label %sw.bb [ ; case [0, 0] holds 1 value; the count 1 is -1 as a signed i1
+    i1 false, label %sw.bb12
+  ]
+
+sw.bb:                                            ; preds = %entry
+  unreachable
+
+sw.bb12:                                          ; preds = %entry
+  unreachable
+}
+
+define void @f_i2(i2 %cond) { ; Regression test: i2 switch whose cases cover every value.
+entry:
+  switch i2 %cond, label %sw.bb [ ; 4 cases, all to %sw.bb12; a count of 4 wraps to 0 in 2 bits
+    i2 0, label %sw.bb12
+    i2 1, label %sw.bb12
+    i2 2, label %sw.bb12
+    i2 3, label %sw.bb12
+  ]
+
+sw.bb:                                            ; preds = %entry; never taken, every i2 value has a case
+  unreachable
+
+sw.bb12:                                          ; preds = %entry
+  unreachable
+}
\ No newline at end of file


        


More information about the llvm-commits mailing list