[llvm] [SimplifyCFG] Find the minimal table considering overflow in `switchToLookupTable` (PR #67885)

via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 30 23:10:11 PDT 2023


================
@@ -6519,17 +6518,60 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
   SmallDenseMap<PHINode *, Type *> ResultTypes;
   SmallVector<PHINode *, 4> PHIs;
 
-  for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) {
-    ConstantInt *CaseVal = CI->getCaseValue();
-    if (CaseVal->getValue().slt(MinCaseVal->getValue()))
-      MinCaseVal = CaseVal;
-    if (CaseVal->getValue().sgt(MaxCaseVal->getValue()))
-      MaxCaseVal = CaseVal;
+  SmallVector<ConstantInt *, 8> CaseVals;
+  for (auto CI : SI->cases()) {
+    ConstantInt *CaseVal = CI.getCaseValue();
+    CaseVals.push_back(CaseVal);
+  }
+
+  // We want to find the range of indexes that produces the smallest table.
+  // Treat all possible index values as a circle: for i8, [-128, -1] and
+  // [0, 127] join up because 127 wraps around to -128. Then find the shortest
+  // arc of this circle that covers all existing values. First, sort them.
+  llvm::sort(CaseVals, [](const ConstantInt *A, const ConstantInt *B) {
+    return A->getValue().slt(B->getValue());
+  });
+  auto *CaseValIter = CaseVals.begin();
+  // Start with the [first, last] range of the sorted values as the candidate.
+  ConstantInt *BeginCaseVal = *CaseValIter;
+  ConstantInt *EndCaseVal = *CaseVals.rbegin();
+  bool RangeOverflow = false;
+  uint64_t MinTableSize = EndCaseVal->getValue()
+                              .ssub_ov(BeginCaseVal->getValue(), RangeOverflow)
+                              .getLimitedValue() +
+                          1;
+  // If there is no overflow, then this must be the minimal table.
+  if (RangeOverflow) {
+    auto MaxValue = APInt::getMaxValue(BeginCaseVal->getBitWidth());
+    while (CaseValIter != CaseVals.end()) {
+      auto *CurrentCaseVal = *CaseValIter++;
+      if (CaseValIter == CaseVals.end()) {
+        break;
+      }
+      ConstantInt *NextCaseVal = *CaseValIter;
+      auto NextVal = NextCaseVal->getValue();
+      auto CurVal = CurrentCaseVal->getValue();
+      uint64_t RequireTableSize =
+          (MaxValue - (NextVal - CurVal) + 1).getLimitedValue() + 1;
----------------
DianQK wrote:

Yes, that might be easier to understand.
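
For readers following the thread, here is a rough standalone sketch of the circular search from the diff. It uses plain `int64_t` values in place of `ConstantInt`/`APInt`, and it is an illustration rather than the patch itself. The helpers `minimalTableSize` and `wrappedTableSize` are hypothetical names. The sketch assumes bit widths of 1 to 32, and it assumes the case values are unique and fit in that signed width.

For each adjacent pair (Cur, Next) in sorted order, the table can start at Next and wrap around to Cur, which skips the gap between them. The size of that table is 2^n - (Next - Cur) + 1. In n-bit wrapping arithmetic, this equals the wrapped difference (Cur - Next) plus 1. The thread does not show the reviewer's exact suggestion, so the second form below is only one equivalent way to write it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not LLVM API): number of table entries needed when the
// index range starts at Next and wraps around the 2^Bits circle to end at Cur,
// where Cur < Next are adjacent sorted case values.
uint64_t wrappedTableSize(int64_t Cur, int64_t Next, unsigned Bits) {
  const uint64_t Mask = (uint64_t(1) << Bits) - 1; // 2^Bits - 1, like MaxValue
  // Gap-based form, as in the diff: (MaxValue - (Next - Cur) + 1) + 1.
  uint64_t ViaGap = (Mask - uint64_t(Next - Cur) + 1) + 1;
  // Equivalent form: the difference Cur - Next, reduced modulo 2^Bits, plus 1.
  uint64_t ViaWrap = (uint64_t(Cur - Next) & Mask) + 1;
  assert(ViaGap == ViaWrap);
  (void)ViaGap; // silence unused-variable warnings when NDEBUG is defined
  return ViaWrap;
}

// Hypothetical helper (not LLVM API): smallest lookup table covering all case
// values of a Bits-wide switch, allowing the index range to wrap around.
uint64_t minimalTableSize(std::vector<int64_t> Vals, unsigned Bits) {
  assert(!Vals.empty() && Bits >= 1 && Bits <= 32);
  std::sort(Vals.begin(), Vals.end()); // signed order, like APInt::slt
  // Candidate 1: the contiguous [min, max] range.
  uint64_t Best = uint64_t(Vals.back() - Vals.front()) + 1;
  // Remaining candidates: skip the gap between each adjacent pair by wrapping.
  for (size_t I = 0; I + 1 < Vals.size(); ++I)
    Best = std::min(Best, wrappedTableSize(Vals[I], Vals[I + 1], Bits));
  return Best;
}
```

Example: for an i8 switch on {-100, 100}, `minimalTableSize({-100, 100}, 8)` returns 57, versus 201 entries for the plain [-100, 100] range. For {-128, 127} it returns 2.

The patch only searches the wrapped candidates when the signed subtraction overflows. This sketch checks every candidate instead, and both approaches give the same minimum. When the subtraction does not overflow, [min, max] spans at most half the circle, so every wrapped alternative is larger.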

https://github.com/llvm/llvm-project/pull/67885

