[llvm] [SimplifyCFG] Improve linear mapping in switch lookup tables (PR #67881)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 23 09:54:40 PDT 2023


================
@@ -6139,46 +6141,80 @@ SwitchLookupTable::SwitchLookupTable(
   // Check if we can derive the value with a linear transformation from the
   // table index.
   if (isa<IntegerType>(ValueType)) {
-    bool LinearMappingPossible = true;
-    APInt PrevVal;
-    APInt DistToPrev;
-    // When linear map is monotonic and signed overflow doesn't happen on
-    // maximum index, we can attach nsw on Add and Mul.
-    bool NonMonotonic = false;
-    assert(TableSize >= 2 && "Should be a SingleValue table.");
-    // Check if there is the same distance between two consecutive values.
-    for (uint64_t I = 0; I < TableSize; ++I) {
-      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
-      if (!ConstVal) {
-        // This is an undef. We could deal with it, but undefs in lookup tables
-        // are very seldom. It's probably not worth the additional complexity.
-        LinearMappingPossible = false;
-        break;
-      }
-      const APInt &Val = ConstVal->getValue();
-      if (I != 0) {
-        APInt Dist = Val - PrevVal;
-        if (I == 1) {
-          DistToPrev = Dist;
-        } else if (Dist != DistToPrev) {
-          LinearMappingPossible = false;
-          break;
+    auto MatchLinearMapping = [&](unsigned *LowBits) {
+      APInt PrevVal;
+      APInt DistToPrev;
+      // When linear map is monotonic and signed overflow doesn't happen on
+      // maximum index, we can attach nsw on Add and Mul.
+      bool NonMonotonic = false;
+      assert(TableSize >= 2 && "Should be a SingleValue table.");
+      // Check if there is the same distance between two consecutive values.
+      for (uint64_t I = 0; I < TableSize; ++I) {
+        ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+        if (!ConstVal) {
+          // This is an undef. We could deal with it, but undefs in lookup
+          // tables are very seldom. It's probably not worth the additional
+          // complexity.
+          return false;
         }
-        NonMonotonic |=
-            Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        const APInt &Val = ConstVal->getValue();
+        if (I != 0) {
+          APInt Dist = Val - PrevVal;
+          if (LowBits)
+            Dist = Dist.getLoBits(*LowBits);
+          if (I == 1)
+            DistToPrev = Dist;
+          else if (Dist != DistToPrev)
+            return false;
+          if (!LowBits)
+            NonMonotonic |=
+                Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        }
+        PrevVal = Val;
       }
-      PrevVal = Val;
-    }
-    if (LinearMappingPossible) {
+
       LinearOffset = cast<ConstantInt>(TableContents[0]);
       LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev);
-      bool MayWrap = false;
-      APInt M = LinearMultiplier->getValue();
-      (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
-      LinearMapValWrapped = NonMonotonic || MayWrap;
+      if (LowBits)
+        LinearMapValWrapped = true;
+      else {
+        bool MayWrap = false;
+        APInt M = LinearMultiplier->getValue();
+        (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
+        LinearMapValWrapped = NonMonotonic || MayWrap;
+      }
       Kind = LinearMapKind;
       ++NumLinearMaps;
+      return true;
+    };
+
+    if (MatchLinearMapping(/* LowBits */ nullptr))
       return;
+    // Try matching highbits | ((offset + index * multiplier) & lowbits_mask)
+    APInt CommonOnes = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    APInt CommonZeros = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    bool IsCommonBitsValid = true;
+    for (uint64_t I = 0; I < TableSize; ++I) {
+      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+      if (!ConstVal) {
+        // ignore undefs
+        IsCommonBitsValid = false;
+        break;
+      }
+      const APInt &Val = ConstVal->getValue();
+      CommonOnes &= Val;
+      CommonZeros &= ~Val;
+    }
+    if (IsCommonBitsValid) {
+      unsigned CommonHighBits = (CommonOnes | CommonZeros).countLeadingOnes();
+      unsigned LowBits = CommonOnes.getBitWidth() - CommonHighBits;
+      assert(LowBits > 0 && "Should be a SingleValue table.");
+      if (CommonHighBits > 0 && MatchLinearMapping(&LowBits)) {
+        LinearMapValMaskedBits = LowBits;
+        LinearMapValHighBits = CommonOnes;
+        LinearMapValHighBits.clearLowBits(LowBits);
----------------
goldsteinn wrote:

This assert should be gratuitous (i.e. redundant), based on LowBits = BitWidth - (CommonOnes | CommonZeros).clz

https://github.com/llvm/llvm-project/pull/67881


More information about the llvm-commits mailing list