[llvm] [SimplifyCFG] Improve linear mapping in switch lookup tables (PR #67881)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 12 01:14:50 PST 2023


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/67881

>From f90732369759224f776049ca19f6f4c6edffbcf4 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 12 Nov 2023 16:58:09 +0800
Subject: [PATCH 1/2] [SimplifyCFG] Add pre-commit tests from PR67843. NFC.

---
 .../SimplifyCFG/X86/switch_to_lookup_table.ll | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
index 3873f0c0ae0bbd5..ee33ce702445ab1 100644
--- a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
+++ b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
@@ -2068,3 +2068,90 @@ cond.end:                                         ; preds = %entry, %cond.false
   %conv = sext i3 %cond to i8
   ret i8 %conv
 }
+
+define i32 @pr67843(i8 %0) {
+; CHECK-LABEL: @pr67843(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.pr67843, i32 0, i8 [[SWITCH_TABLEIDX]]
+; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
+; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5   ; x = 0 -> 0
+  i8 1, label %bb4   ; x = 1 -> 1
+  i8 -1, label %bb1  ; x = -1 -> 255
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable  ; default is unreachable: only the three listed inputs occur
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 255, %bb1 ], [ 1, %bb4 ], [ 0, %start ]  ; by index (x + 1): [255, 0, 1] == (index + 255) & 255
+  ret i32 %.0  ; PR67843: the table is linear in its low 8 bits (step 1 modulo 2^8), not as a full i32
+}
+
+define i32 @linearmap_masked_with_common_highbits(i8 %0) {
+; CHECK-LABEL: @linearmap_masked_with_common_highbits(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.linearmap_masked_with_common_highbits, i32 0, i8 [[SWITCH_TABLEIDX]]
+; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
+; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5   ; x = 0 -> 256 (0x100)
+  i8 1, label %bb4   ; x = 1 -> 257 (0x101)
+  i8 -1, label %bb1  ; x = -1 -> 511 (0x1FF)
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable  ; default is unreachable: only the three listed inputs occur
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 511, %bb1 ], [ 257, %bb4 ], [ 256, %start ]  ; by index (x + 1): [511, 256, 257] == 256 | ((index + 255) & 255)
+  ret i32 %.0  ; bits 8 and up are identical in every entry (0x100); only the low 8 bits vary, linearly
+}
+
+define i32 @linearmap_masked_with_common_highbits_fail(i8 %0) {
+; CHECK-LABEL: @linearmap_masked_with_common_highbits_fail(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
+; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.linearmap_masked_with_common_highbits_fail, i32 0, i8 [[SWITCH_TABLEIDX]]
+; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
+; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+;
+start:
+  switch i8 %0, label %bb2 [
+  i8 0, label %bb5   ; x = 0 -> 256 (0x100)
+  i8 1, label %bb4   ; x = 1 -> 257 (0x101)
+  i8 -1, label %bb1  ; x = -1 -> 1023 (0x3FF)
+  ]
+
+bb2:                                              ; preds = %start
+  unreachable  ; default is unreachable: only the three listed inputs occur
+
+bb4:                                              ; preds = %start
+  br label %bb5
+
+bb1:                                              ; preds = %start
+  br label %bb5
+
+bb5:                                              ; preds = %start, %bb1, %bb4
+  %.0 = phi i32 [ 1023, %bb1 ], [ 257, %bb4 ], [ 256, %start ]  ; by index (x + 1): [1023, 256, 257]
+  ret i32 %.0  ; negative case: bit 9 differs, so only bits 10+ are common; low 10 bits step by 257 then 1 (unequal), so expect a lookup table
+}

>From f622eca95b4536fccdb5b512d4f87ac68cf33638 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 12 Nov 2023 17:13:38 +0800
Subject: [PATCH 2/2] [SimplifyCFG] Improve linear mapping in switch lookup
 tables

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 116 +++++++++++++-----
 .../SimplifyCFG/X86/switch-table-bug.ll       |   8 +-
 .../SimplifyCFG/X86/switch_to_lookup_table.ll |  15 ++-
 3 files changed, 96 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4dae52a8ecffdf6..6b89390ab2cda3b 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6082,6 +6082,8 @@ class SwitchLookupTable {
   ConstantInt *LinearOffset = nullptr;
   ConstantInt *LinearMultiplier = nullptr;
   bool LinearMapValWrapped = false;
+  unsigned LinearMapValMaskedBits = 0;
+  APInt LinearMapValHighBits;
 
   // For ArrayKind, this is the array.
   GlobalVariable *Array = nullptr;
@@ -6139,46 +6141,81 @@ SwitchLookupTable::SwitchLookupTable(
   // Check if we can derive the value with a linear transformation from the
   // table index.
   if (isa<IntegerType>(ValueType)) {
-    bool LinearMappingPossible = true;
-    APInt PrevVal;
-    APInt DistToPrev;
-    // When linear map is monotonic and signed overflow doesn't happen on
-    // maximum index, we can attach nsw on Add and Mul.
-    bool NonMonotonic = false;
-    assert(TableSize >= 2 && "Should be a SingleValue table.");
-    // Check if there is the same distance between two consecutive values.
-    for (uint64_t I = 0; I < TableSize; ++I) {
-      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
-      if (!ConstVal) {
-        // This is an undef. We could deal with it, but undefs in lookup tables
-        // are very seldom. It's probably not worth the additional complexity.
-        LinearMappingPossible = false;
-        break;
-      }
-      const APInt &Val = ConstVal->getValue();
-      if (I != 0) {
-        APInt Dist = Val - PrevVal;
-        if (I == 1) {
-          DistToPrev = Dist;
-        } else if (Dist != DistToPrev) {
-          LinearMappingPossible = false;
-          break;
+    auto MatchLinearMapping = [&](bool MaskOutHighBits, unsigned LowBits) {
+      APInt PrevVal;
+      APInt DistToPrev;
+      // When linear map is monotonic and signed overflow doesn't happen on
+      // maximum index, we can attach nsw on Add and Mul.
+      bool NonMonotonic = false;
+      assert(TableSize >= 2 && "Should be a SingleValue table.");
+      // Check if there is the same distance between two consecutive values.
+      for (uint64_t I = 0; I < TableSize; ++I) {
+        ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+        if (!ConstVal) {
+          // This is an undef. We could deal with it, but undefs in lookup
+          // tables are very seldom. It's probably not worth the additional
+          // complexity.
+          return false;
         }
-        NonMonotonic |=
-            Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        const APInt &Val = ConstVal->getValue();
+        if (I != 0) {
+          APInt Dist = Val - PrevVal;
+          if (MaskOutHighBits)
+            Dist = Dist.getLoBits(LowBits);
+          if (I == 1)
+            DistToPrev = Dist;
+          else if (Dist != DistToPrev)
+            return false;
+          if (!MaskOutHighBits)
+            NonMonotonic |=
+                Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal);
+        }
+        PrevVal = Val;
       }
-      PrevVal = Val;
-    }
-    if (LinearMappingPossible) {
+
       LinearOffset = cast<ConstantInt>(TableContents[0]);
       LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev);
-      bool MayWrap = false;
-      APInt M = LinearMultiplier->getValue();
-      (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
-      LinearMapValWrapped = NonMonotonic || MayWrap;
+      if (MaskOutHighBits)
+        LinearMapValWrapped = true;
+      else {
+        bool MayWrap = false;
+        APInt M = LinearMultiplier->getValue();
+        (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap);
+        LinearMapValWrapped = NonMonotonic || MayWrap;
+      }
       Kind = LinearMapKind;
       ++NumLinearMaps;
+      return true;
+    };
+
+    if (MatchLinearMapping(/* MaskOutHighBits */ false, /* LowBits */ 0))
       return;
+    // Try matching highbits | ((offset + index * multiplier) & lowbits_mask)
+    APInt CommonOnes = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    APInt CommonZeros = APInt::getAllOnes(ValueType->getScalarSizeInBits());
+    bool IsCommonBitsValid = true;
+    for (uint64_t I = 0; I < TableSize; ++I) {
+      ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]);
+      if (!ConstVal) {
+        // ignore undefs
+        IsCommonBitsValid = false;
+        break;
+      }
+      const APInt &Val = ConstVal->getValue();
+      CommonOnes &= Val;
+      CommonZeros &= ~Val;
+    }
+    if (IsCommonBitsValid) {
+      unsigned CommonHighBits = (CommonOnes | CommonZeros).countLeadingOnes();
+      unsigned LowBits = CommonOnes.getBitWidth() - CommonHighBits;
+      assert(LowBits > 0 && "Should be a SingleValue table.");
+      if (CommonHighBits > 0 &&
+          MatchLinearMapping(/* MaskOutHighBits */ true, LowBits)) {
+        LinearMapValMaskedBits = LowBits;
+        LinearMapValHighBits = CommonOnes;
+        LinearMapValHighBits.clearLowBits(LowBits);
+        return;
+      }
     }
   }
 
@@ -6232,6 +6269,19 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
       Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset",
                                  /*HasNUW = */ false,
                                  /*HasNSW = */ !LinearMapValWrapped);
+
+    if (LinearMapValMaskedBits) {
+      Result = Builder.CreateAnd(
+          Result,
+          APInt::getLowBitsSet(
+              cast<IntegerType>(Result->getType())->getBitWidth(),
+              LinearMapValMaskedBits),
+          "switch.masked");
+      if (!LinearMapValHighBits.isZero())
+        Result = Builder.CreateOr(Result, LinearMapValHighBits,
+                                  "switch.with_high_bits");
+    }
+
     return Result;
   }
   case BitMapKind: {
diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
index 37001f4fba2aa84..f865387a24f338a 100644
--- a/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
+++ b/llvm/test/Transforms/SimplifyCFG/X86/switch-table-bug.ll
@@ -10,10 +10,10 @@ define i64 @_TFO6reduce1E5toRawfS0_FT_Si(i2) {
 ; CHECK-LABEL: @_TFO6reduce1E5toRawfS0_FT_Si(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub i2 [[TMP0:%.*]], -2
-; CHECK-NEXT:    [[SWITCH_TABLEIDX_ZEXT:%.*]] = zext i2 [[SWITCH_TABLEIDX]] to i3
-; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [4 x i64], ptr @switch.table._TFO6reduce1E5toRawfS0_FT_Si, i32 0, i3 [[SWITCH_TABLEIDX_ZEXT]]
-; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i64, ptr [[SWITCH_GEP]], align 8
-; CHECK-NEXT:    ret i64 [[SWITCH_LOAD]]
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i2 [[SWITCH_TABLEIDX]] to i64
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i64 [[SWITCH_IDX_CAST]], 2
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i64 [[SWITCH_OFFSET]], 3
+; CHECK-NEXT:    ret i64 [[SWITCH_MASKED]]
 ;
 entry:
   switch i2 %0, label %1 [
diff --git a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
index ee33ce702445ab1..1ddb1d0a79b5912 100644
--- a/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
+++ b/llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
@@ -2073,9 +2073,10 @@ define i32 @pr67843(i8 %0) {
 ; CHECK-LABEL: @pr67843(
 ; CHECK-NEXT:  start:
 ; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
-; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.pr67843, i32 0, i8 [[SWITCH_TABLEIDX]]
-; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
-; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i8 [[SWITCH_TABLEIDX]] to i32
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i32 [[SWITCH_IDX_CAST]], 255
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i32 [[SWITCH_OFFSET]], 255
+; CHECK-NEXT:    ret i32 [[SWITCH_MASKED]]
 ;
 start:
   switch i8 %0, label %bb2 [
@@ -2102,9 +2103,11 @@ define i32 @linearmap_masked_with_common_highbits(i8 %0) {
 ; CHECK-LABEL: @linearmap_masked_with_common_highbits(
 ; CHECK-NEXT:  start:
 ; CHECK-NEXT:    [[SWITCH_TABLEIDX:%.*]] = sub nsw i8 [[TMP0:%.*]], -1
-; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [3 x i32], ptr @switch.table.linearmap_masked_with_common_highbits, i32 0, i8 [[SWITCH_TABLEIDX]]
-; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
-; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+; CHECK-NEXT:    [[SWITCH_IDX_CAST:%.*]] = zext i8 [[SWITCH_TABLEIDX]] to i32
+; CHECK-NEXT:    [[SWITCH_OFFSET:%.*]] = add i32 [[SWITCH_IDX_CAST]], 511
+; CHECK-NEXT:    [[SWITCH_MASKED:%.*]] = and i32 [[SWITCH_OFFSET]], 255
+; CHECK-NEXT:    [[SWITCH_WITH_HIGH_BITS:%.*]] = or i32 [[SWITCH_MASKED]], 256
+; CHECK-NEXT:    ret i32 [[SWITCH_WITH_HIGH_BITS]]
 ;
 start:
   switch i8 %0, label %bb2 [



More information about the llvm-commits mailing list