[llvm] [SimplifyCFG] Improve range reducing for switches (PR #67882)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 30 06:11:47 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
<details>
<summary>Changes</summary>
This patch improves range reduction for switches whose default block is unreachable.
It masks out the high bits that all case values have in common, then rotates the cases so that the largest gap between adjacent case values moves to the top of the range.
Example:
```
int f1(int x) {
switch (x) {
case 0: return 1;
case 1: return 2;
case 255: return 0;
default: __builtin_unreachable();
}
}
```
After range reduction:
```
int f1(int x) {
switch ((x + 1) & 255) {
case 1: return 1;
case 2: return 2;
case 0: return 0;
default: __builtin_unreachable();
}
}
```
Fixes #67842.
---
Full diff: https://github.com/llvm/llvm-project/pull/67882.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+71-3)
- (modified) llvm/test/Transforms/SimplifyCFG/rangereduce.ll (+139)
``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 35fead111aa9666..ef7a8d34b539662 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -80,6 +80,7 @@
#include <cstddef>
#include <cstdint>
#include <iterator>
+#include <limits>
#include <map>
#include <optional>
#include <set>
@@ -6748,6 +6749,71 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Try to reduce the range of cases with an unreachable default.
+///
+/// When the default destination is unreachable, the condition is known to
+/// match one of the cases, so the switch can instead be performed on
+/// `(Cond + Offset) & Mask`. `Mask` keeps only the low bits in which the
+/// (Base-relative) case values differ. `Offset` rotates the values modulo
+/// `Mask + 1` so that the largest gap between cyclically adjacent cases wraps
+/// around past the top of the range. The rewrite is only performed if the
+/// resulting case values form a dense switch.
+///
+/// \param SI      The switch to rewrite.
+/// \param Values  The case values as int64_t, shifted down by \p Base. The
+///                hole search below presumably relies on them being sorted in
+///                ascending order (so Values[0] == 0) -- TODO confirm against
+///                the caller.
+/// \param Base    The amount that was subtracted from each case value.
+/// \param Builder Used to emit the `add`/`and` on the switch condition.
+/// \returns true if the switch was rewritten.
+static bool
+ReduceSwitchRangeWithUnreachableDefault(SwitchInst *SI,
+                                        const SmallVectorImpl<int64_t> &Values,
+                                        uint64_t Base, IRBuilder<> &Builder) {
+  bool HasDefault =
+      !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg());
+  if (HasDefault)
+    return false;
+
+  auto *Ty = cast<IntegerType>(SI->getCondition()->getType());
+  unsigned BitWidth = Ty->getBitWidth();
+
+  // Try reducing the range to (idx + offset) & mask.
+  // First, mask out the high bits that all case values have in common.
+  uint64_t CommonOnes = std::numeric_limits<uint64_t>::max();
+  uint64_t CommonZeros = std::numeric_limits<uint64_t>::max();
+  for (int64_t V : Values) {
+    CommonOnes &= (uint64_t)V;
+    CommonZeros &= ~(uint64_t)V;
+  }
+  unsigned CommonBits = countl_one(CommonOnes | CommonZeros);
+  // Without a common high bit there is nothing to mask away. Bail out before
+  // building the mask, since `1ULL << 64` is undefined behaviour.
+  if (CommonBits == 0)
+    return false;
+  unsigned LowBits = 64 - CommonBits;
+  uint64_t Mask = (1ULL << LowBits) - 1;
+
+  // Now we have some case values in the additive group Z/(2**LowBits)Z.
+  // Find the largest hole between cyclically adjacent values and rotate the
+  // values so that this hole ends up at the back of the range.
+  uint64_t MaxHole = 0;
+  uint64_t BestOffset = 0;
+  for (size_t I = 0, E = Values.size(); I != E; ++I) {
+    uint64_t Prev = (uint64_t)(I == 0 ? Values.back() : Values[I - 1]);
+    uint64_t Hole = ((uint64_t)Values[I] - Prev) & Mask;
+    if (Hole > MaxHole) {
+      MaxHole = Hole;
+      // The offset that maps Values[I] to 0, reduced modulo Mask + 1 so that
+      // it always fits in LowBits bits.
+      BestOffset = (Mask - (uint64_t)Values[I] + 1) & Mask;
+    }
+  }
+
+  // Compute the case values exactly as the rewritten switch will see them:
+  // the BitWidth-bit result of `(Case + Offset) & Mask`, read back as a
+  // sign-extended int64_t. When LowBits < BitWidth, the `and` clears every
+  // bit from LowBits upwards, so these values are non-negative. Sign-extending
+  // from LowBits instead would disagree with the index computed at runtime
+  // and miscompile the switch.
+  SmallVector<int64_t, 4> NewValues;
+  NewValues.reserve(Values.size());
+  for (int64_t V : Values)
+    NewValues.push_back(
+        SignExtend64(((uint64_t)V + BestOffset) & Mask, BitWidth));
+
+  llvm::sort(NewValues);
+  if (!isSwitchDense(NewValues))
+    // Transform didn't create a dense switch.
+    return false;
+
+  // Offset relative to the original, unshifted condition. Base is treated as
+  // a signed value here. NOTE(review): this assumes Base fits in BitWidth
+  // signed bits, i.e. it is a sign-extended case value -- confirm in caller.
+  APInt Offset =
+      APInt(BitWidth, BestOffset) - APInt(BitWidth, Base, /*isSigned=*/true);
+
+  Builder.SetInsertPoint(SI);
+  Value *Index = Builder.CreateAnd(
+      Builder.CreateAdd(SI->getCondition(), ConstantInt::get(Ty, Offset)),
+      Mask);
+  SI->setCondition(Index);
+
+  // Apply the same `(V + Offset) & Mask` to every case value so that the
+  // cases match the new condition bit-for-bit.
+  for (auto Case : SI->cases()) {
+    APInt CaseVal = (Case.getCaseValue()->getValue() + Offset) & Mask;
+    Case.setValue(cast<ConstantInt>(ConstantInt::get(Ty, CaseVal)));
+  }
+
+  return true;
+}
+
/// Try to transform a switch that has "holes" in it to a contiguous sequence
/// of cases.
///
@@ -6763,9 +6829,8 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
if (CondTy->getIntegerBitWidth() > 64 ||
!DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
return false;
- // Only bother with this optimization if there are more than 3 switch cases;
- // SDAG will only bother creating jump tables for 4 or more cases.
- if (SI->getNumCases() < 4)
+ // Ignore switches with less than three cases.
+ if (SI->getNumCases() < 3)
return false;
// This transform is agnostic to the signedness of the input or case values. We
@@ -6786,6 +6851,9 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
for (auto &V : Values)
V -= (uint64_t)(Base);
+ if (ReduceSwitchRangeWithUnreachableDefault(SI, Values, Base, Builder))
+ return true;
+
// Now we have signed numbers that have been shifted so that, given enough
// precision, there are no negative values. Since the rest of the transform
// is bitwise only, we switch now to an unsigned representation.
diff --git a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll
index b1a3802a2bb58b8..939c7b782d579da 100644
--- a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll
+++ b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll
@@ -315,3 +315,142 @@ three:
ret i32 99783
}
+; Regression test for issue #67842: with an unreachable default, the cases
+; 0, 1 and 255 are reduced to (x + 1) & 255, i.e. the dense values 1, 2 and 0.
+; The CHECK lines expect the switch to be replaced by arithmetic on the
+; reduced index.
+define i8 @pr67842(i32 %0) {
+; CHECK-LABEL: @pr67842(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255
+; CHECK-NEXT: [[SWITCH_IDX_CAST:%.*]] = trunc i32 [[TMP2]] to i8
+; CHECK-NEXT: [[SWITCH_OFFSET:%.*]] = add nsw i8 [[SWITCH_IDX_CAST]], -1
+; CHECK-NEXT: ret i8 [[SWITCH_OFFSET]]
+;
+start:
+ switch i32 %0, label %bb2 [
+ i32 0, label %bb5
+ i32 1, label %bb4
+ i32 255, label %bb1
+ ]
+
+bb2: ; preds = %start
+ unreachable
+
+bb4: ; preds = %start
+ br label %bb5
+
+bb1: ; preds = %start
+ br label %bb5
+
+bb5: ; preds = %start, %bb1, %bb4
+ %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ]
+ ret i8 %.0
+}
+
+; The cases 128, 129 and 255 differ from the smallest case only in their low
+; 7 bits, so the condition is reduced to (x - 127) & 127, which maps them to
+; the dense values 1, 2 and 0.
+define i8 @reduce_masked_common_high_bits(i32 %0) {
+; CHECK-LABEL: @reduce_masked_common_high_bits(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0:%.*]], -127
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 127
+; CHECK-NEXT: [[SWITCH_IDX_CAST:%.*]] = trunc i32 [[TMP2]] to i8
+; CHECK-NEXT: [[SWITCH_OFFSET:%.*]] = add nsw i8 [[SWITCH_IDX_CAST]], -1
+; CHECK-NEXT: ret i8 [[SWITCH_OFFSET]]
+;
+start:
+ switch i32 %0, label %bb2 [
+ i32 128, label %bb5
+ i32 129, label %bb4
+ i32 255, label %bb1
+ ]
+
+bb2: ; preds = %start
+ unreachable
+
+bb4: ; preds = %start
+ br label %bb5
+
+bb1: ; preds = %start
+ br label %bb5
+
+bb5: ; preds = %start, %bb1, %bb4
+ %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ]
+ ret i8 %.0
+}
+
+; Negative test: relative to the smallest case, the values are 0, 1 and 383,
+; which need 9 low bits. No rotation makes these three cases dense, so the
+; CHECK lines expect the switch to be left unchanged.
+define i8 @reduce_masked_common_high_bits_fail(i32 %0) {
+; CHECK-LABEL: @reduce_masked_common_high_bits_fail(
+; CHECK-NEXT: start:
+; CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[BB2:%.*]] [
+; CHECK-NEXT: i32 128, label [[BB5:%.*]]
+; CHECK-NEXT: i32 129, label [[BB4:%.*]]
+; CHECK-NEXT: i32 511, label [[BB1:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb2:
+; CHECK-NEXT: unreachable
+; CHECK: bb4:
+; CHECK-NEXT: br label [[BB5]]
+; CHECK: bb1:
+; CHECK-NEXT: br label [[BB5]]
+; CHECK: bb5:
+; CHECK-NEXT: [[DOT0:%.*]] = phi i8 [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[START:%.*]] ]
+; CHECK-NEXT: ret i8 [[DOT0]]
+;
+start:
+ switch i32 %0, label %bb2 [
+ i32 128, label %bb5
+ i32 129, label %bb4
+ i32 511, label %bb1
+ ]
+
+bb2: ; preds = %start
+ unreachable
+
+bb4: ; preds = %start
+ br label %bb5
+
+bb1: ; preds = %start
+ br label %bb5
+
+bb5: ; preds = %start, %bb1, %bb4
+ %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ]
+ ret i8 %.0
+}
+
+; Negative test: the default block is reachable (it returns 24), so the
+; unreachable-default range reduction must not trigger. The CHECK lines expect
+; the switch cases to be left unchanged.
+define i8 @reduce_masked_default_reachable(i32 %0) {
+; CHECK-LABEL: @reduce_masked_default_reachable(
+; CHECK-NEXT: start:
+; CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[COMMON_RET:%.*]] [
+; CHECK-NEXT: i32 0, label [[BB5:%.*]]
+; CHECK-NEXT: i32 1, label [[BB4:%.*]]
+; CHECK-NEXT: i32 255, label [[BB1:%.*]]
+; CHECK-NEXT: ]
+; CHECK: common.ret:
+; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i8 [ [[DOT0:%.*]], [[BB5]] ], [ 24, [[START:%.*]] ]
+; CHECK-NEXT: ret i8 [[COMMON_RET_OP]]
+; CHECK: bb4:
+; CHECK-NEXT: br label [[BB5]]
+; CHECK: bb1:
+; CHECK-NEXT: br label [[BB5]]
+; CHECK: bb5:
+; CHECK-NEXT: [[DOT0]] = phi i8 [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[START]] ]
+; CHECK-NEXT: br label [[COMMON_RET]]
+;
+start:
+ switch i32 %0, label %bb2 [
+ i32 0, label %bb5
+ i32 1, label %bb4
+ i32 255, label %bb1
+ ]
+
+bb2: ; preds = %start
+ ret i8 24
+
+bb4: ; preds = %start
+ br label %bb5
+
+bb1: ; preds = %start
+ br label %bb5
+
+bb5: ; preds = %start, %bb1, %bb4
+ %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ]
+ ret i8 %.0
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/67882
More information about the llvm-commits
mailing list