[llvm] [SimplifyCFG] Improve linear mapping in switch lookup tables (PR #67881)

via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 30 05:06:57 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

<details>
<summary>Changes</summary>

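This patch extends the linear-mapping optimization for switch lookup tables to also handle tables whose values match the pattern `highbits | ((offset + index * multiplier) & lowbits_mask)`, i.e. all values share the same high bits, and their low bits form an arithmetic sequence modulo 2^N (where N is the number of low bits).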

Example:
```
int f1(int x) {
  switch (x) {
    case 0: return 255;
    case 1: return 0;
    case 2: return 1;
    case 3: return 2;
    default: __builtin_unreachable();
  }
}
```
With this patch, the switch is folded into a masked linear mapping instead of a lookup table:
```
define i32 @f1(i32 %x) {
entry:
  %switch.offset = add i32 %x, 255
  %switch.masked = and i32 %switch.offset, 255
  ret i32 %switch.masked
}
```

Fixes #67843.


---
Full diff: https://github.com/llvm/llvm-project/pull/67881.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+82-33) 
- (modified) llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll (+4-4) 
- (modified) llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll (+93) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 35fead111aa9666..ff323e43866b477 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6082,6 +6082,8 @@ class SwitchLookupTable {
   ConstantInt *LinearOffset = nullptr;
   ConstantInt *LinearMultiplier = nullptr;
   bool LinearMapValWrapped = false;
+  unsigned LinearMapValMaskedBits = 0;
+  APInt LinearMapValHighBits;
 
   // For ArrayKind, this is the array.
   GlobalVariable *Array = nullptr;
@@ -6139,46 +6141,80 @@ SwitchLookupTable::SwitchLookupTable(
   // Check if we can derive the value with a linear transformation from the
   // table index.
   if (isa<IntegerType>(ValueType)) {
-    bool LinearMappingPossible = true;
-    APInt PrevVal;
-    APInt DistToPrev;
-    // When linear map is monotonic and signed overflow doesn't happen on
-    // maximum index, we can attach nsw on Add and Mul.
-    bool NonMonotonic = false;
-    assert(TableSize >= 2 && "Should be a SingleValue table.");
-    // Check if there is the same distance between two consecutive values.
-    for (uint64_t I = 0; I < TableSize; ++I) {
-      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
-      if (!ConstVal) {
-        // This is an undef. We could deal with it, but undefs in lookup tables
-        // are very seldom. It's probably not worth the additional complexity.
-        LinearMappingPossible = false;
-        break;
-      }
-      const APInt &Val = ConstVal->getValue();
-      if (I != 0) {
-        APInt Dist = Val - PrevVal;
-        if (I == 1) {
-          DistToPrev = Dist;
-        } else if (Dist != DistToPrev) {
-          LinearMappingPossible = false;
-          break;
+    auto MatchLinearMapping = [&](unsigned *LowBits) {
+      APInt PrevVal;
+      APInt DistToPrev;
+      // When linear map is monotonic and signed overflow doesn't happen on
+      // maximum index, we can attach nsw on Add and Mul.
+      bool NonMonotonic = false;
+      assert(TableSize >= 2 && "Should be a SingleValue table.");
+      // Check if there is the same distance between two consecutive values.
+      for (uint64_t I = 0; I < TableSize; ++I) {
+        ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+        if (!ConstVal) {
+          // This is an undef. We could deal with it, but undefs in lookup
+          // tables are very seldom. It's probably not worth the additional
+          // complexity.
+          return false;
         }
-        NonMonotonic |=
-            Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        const APInt &Val = ConstVal->getValue();
+        if (I != 0) {
+          APInt Dist = Val - PrevVal;
+          if (LowBits)
+            Dist = Dist.getLoBits(*LowBits);
+          if (I == 1)
+            DistToPrev = Dist;
+          else if (Dist != DistToPrev)
+            return false;
+          if (!LowBits)
+            NonMonotonic |=
+                Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        }
+        PrevVal = Val;
       }
-      PrevVal = Val;
-    }
-    if (LinearMappingPossible) {
+
       LinearOffset = cast<ConstantInt>(TableContents[0]);
       LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev);
-      bool MayWrap = false;
-      APInt M = LinearMultiplier->getValue();
-      (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
-      LinearMapValWrapped = NonMonotonic || MayWrap;
+      if (LowBits)
+        LinearMapValWrapped = true;
+      else {
+        bool MayWrap = false;
+        APInt M = LinearMultiplier->getValue();
+        (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
+        LinearMapValWrapped = NonMonotonic || MayWrap;
+      }
       Kind = LinearMapKind;
       ++NumLinearMaps;
+      return true;
+    };
+
+    if (MatchLinearMapping(/* LowBits */ nullptr))
       return;
+    // Try matching highbits | ((offset + index * multiplier) & lowbits_mask)
+    APInt CommonOnes = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    APInt CommonZeros = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    bool IsCommonBitsValid = true;
+    for (uint64_t I = 0; I < TableSize; ++I) {
+      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+      if (!ConstVal) {
+        // ignore undefs
+        IsCommonBitsValid = false;
+        break;
+      }
+      const APInt &Val = ConstVal->getValue();
+      CommonOnes &= Val;
+      CommonZeros &= ~Val;
+    }
+    if (IsCommonBitsValid) {
+      unsigned CommonHighBits = (CommonOnes | CommonZeros).countLeadingOnes();
+      unsigned LowBits = CommonOnes.getBitWidth() - CommonHighBits;
+      assert(LowBits > 0 && "Should be a SingleValue table.");
+      if (CommonHighBits > 0 && MatchLinearMapping(&LowBits)) {
+        LinearMapValMaskedBits = LowBits;
+        LinearMapValHighBits = CommonOnes;
+        LinearMapValHighBits.clearLowBits(LowBits);
+        return;
+      }
     }
   }
 
@@ -6232,6 +6268,19 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
       Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset",
                                  /*HasNUW = */ false,
                                  /*HasNSW = */ !LinearMapValWrapped);
+
+    if (LinearMapValMaskedBits) {
+      Result = Builder.CreateAnd(
+          Result,
+          APInt::getLowBitsSet(
+              cast<IntegerType>(Result->getType())->getBitWidth(),
+              LinearMapValMaskedBits),
+          "switch.masked");
+      if (!LinearMapValHighBits.isZero())
+        Result = Builder.CreateOr(Result, LinearMapValHighBits,
+                                  "switch.with_high_bits");
+    }
+
     return Result;
   }
   case BitMapKind: {
diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
index 37001f4fba2aa84..f865387a24f338a 100644
--- a/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
+++ b/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
@@ -10,10 +10,10 @@ define i64 @_TFO6reduce1E5toRawfS0_FT_Si(i2) {
 ; CHECK-LABEL: @_TFO6reduce1E5toRawfS0_FT_Si(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub i2 [[TMP0:%.*]], -2
-; CHECK-NEXT:    [[SWITCH_TABLEIDX_ZEXT:%.*]] = zext i2 [[SWITCH_TABLEIDX]] to i3
-; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [4 x i64], ptr @switch.table._TFO6reduce1E5toRawfS0_FT_Si, i32 0, i3 [[SWITCH_TABLEIDX_ZEXT]]
-; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i64, ptr [[SWITCH_GEP]], align 8
-; CHECK-NEXT:    ret i64 [[SWITCH_LOAD]]
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i2 [[SWITCH_TABLEIDX]] to i64
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i64 [[SWITCH_IDX_CAST]], 2
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i64 [[SWITCH_OFFSET]], 3
+; CHECK-NEXT:    ret i64 [[SWITCH_MASKED]]
 ;
 entry:
   switch i2 %0, label %1 [
diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
index 3873f0c0ae0bbd5..e144452dff4e9de 100644
--- a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
+++ b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
@@ -28,6 +28,9 @@ target triple = "x86_64-unknown-linux-gnu"
 ; The table for @unreachable_case
 ; CHECK: @switch.table.unreachable_case = private unnamed_addr constant [9 x i32] [i32 0, i32 0, i32 0, i32 2, i32 -1, i32 1, i32 1, i32 1, i32 1], align 4
 
+; The table for @linearmap_masked_with_common_highbits_fail
+; CHECK: @switch.table.linearmap_masked_with_common_highbits_fail = private unnamed_addr constant [3 x i32] [i32 1023, i32 256, i32 257], align 4
+
 ; A simple int-to-int selection switch.
 ; It is dense enough to be replaced by table lookup.
 ; The result is directly by a ret from an otherwise empty bb,
@@ -2068,3 +2071,93 @@ cond.end:                                         ; preds = %entry, %cond.false
   %conv = sext i3 %cond to i8
   ret i8 %conv
 }
+
+define i32 @pr67843(i8 %0) {
+; CHECK-LABEL: @pr67843(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i8 [[SWITCH_TABLEIDX]] to i32
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i32 [[SWITCH_IDX_CAST]], 255
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i32 [[SWITCH_OFFSET]], 255
+; CHECK-NEXT:    ret i32 [[SWITCH_MASKED]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5
+  i8 1, label %bb4
+  i8 -1, label %bb1
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 255, %bb1 ], [ 1, %bb4 ], [ 0, %start ]
+  ret i32 %.0
+}
+
+define i32 @linearmap_masked_with_common_highbits(i8 %0) {
+; CHECK-LABEL: @linearmap_masked_with_common_highbits(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i8 [[SWITCH_TABLEIDX]] to i32
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i32 [[SWITCH_IDX_CAST]], 511
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i32 [[SWITCH_OFFSET]], 255
+; CHECK-NEXT:    [[SWITCH_WITH_HIGH_BITS:%.*]] = or i32 [[SWITCH_MASKED]], 256
+; CHECK-NEXT:    ret i32 [[SWITCH_WITH_HIGH_BITS]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5
+  i8 1, label %bb4
+  i8 -1, label %bb1
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 511, %bb1 ], [ 257, %bb4 ], [ 256, %start ]
+  ret i32 %.0
+}
+
+define i32 @linearmap_masked_with_common_highbits_fail(i8 %0) {
+; CHECK-LABEL: @linearmap_masked_with_common_highbits_fail(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.linearmap_masked_with_common_highbits_fail, i32 0, i8 [[SWITCH_TABLEIDX]]
+; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
+; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5
+  i8 1, label %bb4
+  i8 -1, label %bb1
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 1023, %bb1 ], [ 257, %bb4 ], [ 256, %start ]
+  ret i32 %.0
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/67881


More information about the llvm-commits mailing list