[llvm] [SwitchLowering] Support merging 0 and power-of-2 case. (PR #139736)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 27 04:03:02 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: Florian Hahn (fhahn)
<details>
<summary>Changes</summary>
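Add a new switch lowering kind that splits off switch cases for 0 and another
power-of-2 constant P, when both cases branch to the same destination, into a
single AND + ICMP + BR: (X & ~P) == 0 holds exactly when X is 0 or P. This
removes a branch, which can be highly profitable.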
Alive2 proof showing that a power-of-2 constant is required:
https://alive2.llvm.org/ce/z/VIMMNB.
---
Full diff: https://github.com/llvm/llvm-project/pull/139736.diff
6 Files Affected:
- (modified) llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h (+7-7)
- (modified) llvm/include/llvm/CodeGen/SwitchLoweringUtils.h (+4-1)
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+11-12)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+14-2)
- (modified) llvm/lib/CodeGen/SwitchLoweringUtils.cpp (+34)
- (modified) llvm/test/CodeGen/AArch64/switch-cases-to-branch-and.ll (+57-71)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
index 6fd05c8fddd5f..b1ae8171ce2d2 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
@@ -405,13 +405,13 @@ class IRTranslator : public MachineFunctionPass {
BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);
- bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
- MachineBasicBlock *Fallthrough,
- bool FallthroughUnreachable,
- BranchProbability UnhandledProbs,
- MachineBasicBlock *CurMBB,
- MachineIRBuilder &MIB,
- MachineBasicBlock *SwitchMBB);
+ bool lowerSwitchAndOrRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
+ MachineBasicBlock *Fallthrough,
+ bool FallthroughUnreachable,
+ BranchProbability UnhandledProbs,
+ MachineBasicBlock *CurMBB,
+ MachineIRBuilder &MIB,
+ MachineBasicBlock *SwitchMBB);
bool lowerBitTestWorkItem(
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
diff --git a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
index 9f1d6f7b4f952..6b7cb8d9ce45a 100644
--- a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
+++ b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
@@ -35,7 +35,8 @@ enum CaseClusterKind {
/// A cluster of cases suitable for jump table lowering.
CC_JumpTable,
/// A cluster of cases suitable for bit test lowering.
- CC_BitTests
+ CC_BitTests,
+ CC_And
};
/// A cluster of case labels.
@@ -141,6 +142,8 @@ struct CaseBlock {
BranchProbability TrueProb, FalseProb;
bool IsUnpredictable;
+ bool EmitAnd = false;
+
// Constructor for SelectionDAG.
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
const Value *cmpmiddle, MachineBasicBlock *truebb,
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index fe5dcd14d8804..eb07a730ac8d5 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1059,18 +1059,15 @@ bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
}
return true;
}
-bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
- Value *Cond,
- MachineBasicBlock *Fallthrough,
- bool FallthroughUnreachable,
- BranchProbability UnhandledProbs,
- MachineBasicBlock *CurMBB,
- MachineIRBuilder &MIB,
- MachineBasicBlock *SwitchMBB) {
+bool IRTranslator::lowerSwitchAndOrRangeWorkItem(
+ SwitchCG::CaseClusterIt I, Value *Cond, MachineBasicBlock *Fallthrough,
+ bool FallthroughUnreachable, BranchProbability UnhandledProbs,
+ MachineBasicBlock *CurMBB, MachineIRBuilder &MIB,
+ MachineBasicBlock *SwitchMBB) {
using namespace SwitchCG;
const Value *RHS, *LHS, *MHS;
CmpInst::Predicate Pred;
- if (I->Low == I->High) {
+ if (I->Low == I->High || I->Kind == CC_And) {
// Check Cond == I->Low.
Pred = CmpInst::ICMP_EQ;
LHS = Cond;
@@ -1088,6 +1085,7 @@ bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
// The false probability is the sum of all unhandled cases.
CaseBlock CB(Pred, FallthroughUnreachable, LHS, RHS, MHS, I->MBB, Fallthrough,
CurMBB, MIB.getDebugLoc(), I->Prob, UnhandledProbs);
+ CB.EmitAnd = I->Kind == CC_And;
emitSwitchCase(CB, SwitchMBB, MIB);
return true;
@@ -1327,10 +1325,11 @@ bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
}
break;
}
+ case CC_And:
case CC_Range: {
- if (!lowerSwitchRangeWorkItem(I, Cond, Fallthrough,
- FallthroughUnreachable, UnhandledProbs,
- CurMBB, MIB, SwitchMBB)) {
+ if (!lowerSwitchAndOrRangeWorkItem(I, Cond, Fallthrough,
+ FallthroughUnreachable, UnhandledProbs,
+ CurMBB, MIB, SwitchMBB)) {
LLVM_DEBUG(dbgs() << "Failed to lower switch range");
return false;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca195cb37de8a..f1a2b07180cd4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2857,7 +2857,17 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
EVT MemVT = TLI.getMemValueType(DAG.getDataLayout(), CB.CmpLHS->getType());
// Build the setcc now.
- if (!CB.CmpMHS) {
+ if (CB.EmitAnd) {
+ SDLoc dl = getCurSDLoc();
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ EVT VT = TLI.getValueType(DAG.getDataLayout(), CB.CmpRHS->getType(), true);
+ SDValue C = DAG.getConstant(*cast<ConstantInt>(CB.CmpRHS), dl, VT);
+ SDValue Zero = DAG.getConstant(0, dl, VT);
+ SDValue CondLHS = getValue(CB.CmpLHS);
+ SDValue And = DAG.getNode(ISD::AND, dl, C.getValueType(), CondLHS, C);
+ Cond = DAG.getSetCC(dl, MVT::i1, And, Zero, ISD::SETEQ);
+ } else if (!CB.CmpMHS) {
// Fold "(X == true)" to X and "(X == false)" to !X to
// handle common cases produced by branch lowering.
if (CB.CmpRHS == ConstantInt::getTrue(*DAG.getContext()) &&
@@ -12250,10 +12260,11 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
}
break;
}
+ case CC_And:
case CC_Range: {
const Value *RHS, *LHS, *MHS;
ISD::CondCode CC;
- if (I->Low == I->High) {
+ if (I->Low == I->High || I->Kind == CC_And) {
// Check Cond == I->Low.
CC = ISD::SETEQ;
LHS = Cond;
@@ -12275,6 +12286,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
CaseBlock CB(CC, LHS, RHS, MHS, I->MBB, Fallthrough, CurMBB,
getCurSDLoc(), I->Prob, UnhandledProbs);
+ CB.EmitAnd = I->Kind == CC_And;
if (CurMBB == SwitchMBB)
visitSwitchCase(CB, SwitchMBB);
else
diff --git a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
index 038c499fe236e..11122f120eddc 100644
--- a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
+++ b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
@@ -362,6 +362,40 @@ void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
}
}
Clusters.resize(DstIndex);
+
+ unsigned ZeroIdx = -1;
+ for (const auto &[Idx, C] : enumerate(Clusters)) {
+ if (C.Kind != CC_Range || C.Low != C.High)
+ continue;
+ if (C.Low->isZero()) {
+ ZeroIdx = Idx;
+ break;
+ }
+ }
+
+ if (ZeroIdx == -1u)
+ return;
+
+ unsigned Pow2Idx = -1;
+ for (const auto &[Idx, C] : enumerate(Clusters)) {
+ if (C.Kind != CC_Range || C.Low != C.High || C.MBB != Clusters[ZeroIdx].MBB)
+ continue;
+ if (C.Low->getValue().isPowerOf2()) {
+ Pow2Idx = Idx;
+ break;
+ }
+ }
+
+ if (Pow2Idx == -1u)
+ return;
+
+ APInt Pow2 = Clusters[Pow2Idx].Low->getValue();
+ APInt NewC = (Pow2 + 1) * -1;
+ Clusters[ZeroIdx].Low = ConstantInt::get(SI->getContext(), NewC);
+ Clusters[ZeroIdx].High = ConstantInt::get(SI->getContext(), NewC);
+ Clusters[ZeroIdx].Kind = CC_And;
+ Clusters[ZeroIdx].Prob += Clusters[Pow2Idx].Prob;
+ Clusters.erase(Clusters.begin() + Pow2Idx);
}
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
diff --git a/llvm/test/CodeGen/AArch64/switch-cases-to-branch-and.ll b/llvm/test/CodeGen/AArch64/switch-cases-to-branch-and.ll
index 04d4ce8493e1b..b6f3da072ec63 100644
--- a/llvm/test/CodeGen/AArch64/switch-cases-to-branch-and.ll
+++ b/llvm/test/CodeGen/AArch64/switch-cases-to-branch-and.ll
@@ -4,30 +4,25 @@
define i32 @switch_with_matching_dests_0_and_pow2_3_cases(i8 %v) {
; CHECK-LABEL: switch_with_matching_dests_0_and_pow2_3_cases:
; CHECK: ; %bb.0: ; %entry
-; CHECK-NEXT: mov w9, #100 ; =0x64
-; CHECK-NEXT: mov w8, #20 ; =0x14
+; CHECK-NEXT: mov w8, #100 ; =0x64
+; CHECK-NEXT: mov w9, #223 ; =0xdf
; CHECK-NEXT: LBB0_1: ; %loop.header
; CHECK-NEXT: ; =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: ands w10, w0, #0xff
-; CHECK-NEXT: b.eq LBB0_6
+; CHECK-NEXT: tst w0, w9
+; CHECK-NEXT: b.eq LBB0_4
; CHECK-NEXT: ; %bb.2: ; %loop.header
; CHECK-NEXT: ; in Loop: Header=BB0_1 Depth=1
-; CHECK-NEXT: cmp w10, #32
-; CHECK-NEXT: b.eq LBB0_6
-; CHECK-NEXT: ; %bb.3: ; %loop.header
-; CHECK-NEXT: ; in Loop: Header=BB0_1 Depth=1
+; CHECK-NEXT: and w10, w0, #0xff
; CHECK-NEXT: cmp w10, #124
-; CHECK-NEXT: b.eq LBB0_7
-; CHECK-NEXT: ; %bb.4: ; %loop.latch
+; CHECK-NEXT: b.eq LBB0_5
+; CHECK-NEXT: ; %bb.3: ; %loop.latch
; CHECK-NEXT: ; in Loop: Header=BB0_1 Depth=1
-; CHECK-NEXT: subs w9, w9, #1
+; CHECK-NEXT: subs w8, w8, #1
; CHECK-NEXT: b.ne LBB0_1
-; CHECK-NEXT: ; %bb.5:
-; CHECK-NEXT: mov w8, #20 ; =0x14
-; CHECK-NEXT: LBB0_6: ; %common.ret
-; CHECK-NEXT: mov w0, w8
+; CHECK-NEXT: LBB0_4:
+; CHECK-NEXT: mov w0, #20 ; =0x14
; CHECK-NEXT: ret
-; CHECK-NEXT: LBB0_7: ; %e2
+; CHECK-NEXT: LBB0_5: ; %e2
; CHECK-NEXT: mov w0, #30 ; =0x1e
; CHECK-NEXT: ret
entry:
@@ -56,30 +51,28 @@ e2:
define i32 @switch_with_matching_dests_0_and_pow2_3_cases_swapped(i8 %v) {
; CHECK-LABEL: switch_with_matching_dests_0_and_pow2_3_cases_swapped:
; CHECK: ; %bb.0: ; %entry
-; CHECK-NEXT: mov w9, #100 ; =0x64
-; CHECK-NEXT: mov w8, #20 ; =0x14
+; CHECK-NEXT: mov w8, #100 ; =0x64
+; CHECK-NEXT: mov w9, #223 ; =0xdf
; CHECK-NEXT: LBB1_1: ; %loop.header
; CHECK-NEXT: ; =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: ands w10, w0, #0xff
-; CHECK-NEXT: b.eq LBB1_6
+; CHECK-NEXT: tst w0, w9
+; CHECK-NEXT: b.eq LBB1_5
; CHECK-NEXT: ; %bb.2: ; %loop.header
; CHECK-NEXT: ; in Loop: Header=BB1_1 Depth=1
-; CHECK-NEXT: cmp w10, #32
-; CHECK-NEXT: b.eq LBB1_6
-; CHECK-NEXT: ; %bb.3: ; %loop.header
-; CHECK-NEXT: ; in Loop: Header=BB1_1 Depth=1
+; CHECK-NEXT: and w10, w0, #0xff
; CHECK-NEXT: cmp w10, #124
-; CHECK-NEXT: b.eq LBB1_7
-; CHECK-NEXT: ; %bb.4: ; %loop.latch
+; CHECK-NEXT: b.eq LBB1_6
+; CHECK-NEXT: ; %bb.3: ; %loop.latch
; CHECK-NEXT: ; in Loop: Header=BB1_1 Depth=1
-; CHECK-NEXT: subs w9, w9, #1
+; CHECK-NEXT: subs w8, w8, #1
; CHECK-NEXT: b.ne LBB1_1
-; CHECK-NEXT: ; %bb.5:
-; CHECK-NEXT: mov w8, #10 ; =0xa
-; CHECK-NEXT: LBB1_6: ; %common.ret
-; CHECK-NEXT: mov w0, w8
+; CHECK-NEXT: ; %bb.4:
+; CHECK-NEXT: mov w0, #10 ; =0xa
; CHECK-NEXT: ret
-; CHECK-NEXT: LBB1_7: ; %e2
+; CHECK-NEXT: LBB1_5:
+; CHECK-NEXT: mov w0, #20 ; =0x14
+; CHECK-NEXT: ret
+; CHECK-NEXT: LBB1_6: ; %e2
; CHECK-NEXT: mov w0, #30 ; =0x1e
; CHECK-NEXT: ret
entry:
@@ -111,35 +104,33 @@ e2:
define i32 @switch_with_matching_dests_0_and_pow2_3_cases_with_phi(i8 %v, i1 %c) {
; CHECK-LABEL: switch_with_matching_dests_0_and_pow2_3_cases_with_phi:
; CHECK: ; %bb.0: ; %entry
-; CHECK-NEXT: tbz w1, #0, LBB2_8
+; CHECK-NEXT: tbz w1, #0, LBB2_6
; CHECK-NEXT: ; %bb.1: ; %loop.header.preheader
-; CHECK-NEXT: mov w9, #100 ; =0x64
-; CHECK-NEXT: mov w8, #20 ; =0x14
+; CHECK-NEXT: mov w8, #100 ; =0x64
+; CHECK-NEXT: mov w9, #223 ; =0xdf
; CHECK-NEXT: LBB2_2: ; %loop.header
; CHECK-NEXT: ; =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: ands w10, w0, #0xff
+; CHECK-NEXT: tst w0, w9
; CHECK-NEXT: b.eq LBB2_7
; CHECK-NEXT: ; %bb.3: ; %loop.header
; CHECK-NEXT: ; in Loop: Header=BB2_2 Depth=1
-; CHECK-NEXT: cmp w10, #32
-; CHECK-NEXT: b.eq LBB2_7
-; CHECK-NEXT: ; %bb.4: ; %loop.header
-; CHECK-NEXT: ; in Loop: Header=BB2_2 Depth=1
+; CHECK-NEXT: and w10, w0, #0xff
; CHECK-NEXT: cmp w10, #124
-; CHECK-NEXT: b.eq LBB2_9
-; CHECK-NEXT: ; %bb.5: ; %loop.latch
+; CHECK-NEXT: b.eq LBB2_8
+; CHECK-NEXT: ; %bb.4: ; %loop.latch
; CHECK-NEXT: ; in Loop: Header=BB2_2 Depth=1
-; CHECK-NEXT: subs w9, w9, #1
+; CHECK-NEXT: subs w8, w8, #1
; CHECK-NEXT: b.ne LBB2_2
-; CHECK-NEXT: ; %bb.6:
-; CHECK-NEXT: mov w8, #10 ; =0xa
-; CHECK-NEXT: LBB2_7: ; %common.ret
-; CHECK-NEXT: mov w0, w8
+; CHECK-NEXT: ; %bb.5:
+; CHECK-NEXT: mov w0, #10 ; =0xa
; CHECK-NEXT: ret
-; CHECK-NEXT: LBB2_8:
+; CHECK-NEXT: LBB2_6:
; CHECK-NEXT: mov w0, wzr
; CHECK-NEXT: ret
-; CHECK-NEXT: LBB2_9: ; %e2
+; CHECK-NEXT: LBB2_7:
+; CHECK-NEXT: mov w0, #20 ; =0x14
+; CHECK-NEXT: ret
+; CHECK-NEXT: LBB2_8: ; %e2
; CHECK-NEXT: mov w0, #30 ; =0x1e
; CHECK-NEXT: ret
entry:
@@ -240,21 +231,18 @@ define i32 @switch_in_loop_with_matching_dests_0_and_pow2_3_cases(ptr %start) {
; CHECK-NEXT: LBB4_1: ; %loop
; CHECK-NEXT: ; =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ldrb w9, [x8], #1
-; CHECK-NEXT: cbz w9, LBB4_4
+; CHECK-NEXT: tst w9, #0xffffffdf
+; CHECK-NEXT: b.eq LBB4_4
; CHECK-NEXT: ; %bb.2: ; %loop
; CHECK-NEXT: ; in Loop: Header=BB4_1 Depth=1
; CHECK-NEXT: cmp w9, #124
-; CHECK-NEXT: b.eq LBB4_5
-; CHECK-NEXT: ; %bb.3: ; %loop
-; CHECK-NEXT: ; in Loop: Header=BB4_1 Depth=1
-; CHECK-NEXT: cmp w9, #32
; CHECK-NEXT: b.ne LBB4_1
+; CHECK-NEXT: ; %bb.3: ; %e2.loopexit
+; CHECK-NEXT: mov w0, wzr
+; CHECK-NEXT: ret
; CHECK-NEXT: LBB4_4: ; %e1
; CHECK-NEXT: mov w0, #-1 ; =0xffffffff
; CHECK-NEXT: ret
-; CHECK-NEXT: LBB4_5: ; %e2.loopexit
-; CHECK-NEXT: mov w0, wzr
-; CHECK-NEXT: ret
entry:
br label %loop
@@ -376,8 +364,7 @@ define void @test_successor_with_loop_phi(ptr %A, ptr %B) {
; CHECK-NEXT: ldr w8, [x0]
; CHECK-NEXT: str wzr, [x0]
; CHECK-NEXT: mov x0, x1
-; CHECK-NEXT: orr w8, w8, #0x4
-; CHECK-NEXT: cmp w8, #4
+; CHECK-NEXT: tst w8, #0xfffffffb
; CHECK-NEXT: b.eq LBB7_1
; CHECK-NEXT: ; %bb.2: ; %exit
; CHECK-NEXT: ret
@@ -556,22 +543,21 @@ e1:
define void @merge_with_stores(ptr %A, i16 %v) {
; CHECK-LABEL: merge_with_stores:
; CHECK: ; %bb.0: ; %entry
-; CHECK-NEXT: and w8, w1, #0xffff
-; CHECK-NEXT: sub w9, w8, #10
-; CHECK-NEXT: cmp w9, #2
-; CHECK-NEXT: b.lo LBB11_4
+; CHECK-NEXT: mov w8, #65533 ; =0xfffd
+; CHECK-NEXT: tst w1, w8
+; CHECK-NEXT: b.eq LBB11_3
; CHECK-NEXT: ; %bb.1: ; %entry
-; CHECK-NEXT: cbz w8, LBB11_5
-; CHECK-NEXT: ; %bb.2: ; %entry
+; CHECK-NEXT: and w8, w1, #0xffff
+; CHECK-NEXT: sub w8, w8, #10
; CHECK-NEXT: cmp w8, #2
-; CHECK-NEXT: b.eq LBB11_5
-; CHECK-NEXT: ; %bb.3: ; %default.dst
-; CHECK-NEXT: strh wzr, [x0]
-; CHECK-NEXT: ret
-; CHECK-NEXT: LBB11_4: ; %other.dst
+; CHECK-NEXT: b.hs LBB11_4
+; CHECK-NEXT: ; %bb.2: ; %other.dst
; CHECK-NEXT: mov w8, #1 ; =0x1
; CHECK-NEXT: strh w8, [x0, #36]
-; CHECK-NEXT: LBB11_5: ; %pow2.dst
+; CHECK-NEXT: LBB11_3: ; %pow2.dst
+; CHECK-NEXT: ret
+; CHECK-NEXT: LBB11_4: ; %default.dst
+; CHECK-NEXT: strh wzr, [x0]
; CHECK-NEXT: ret
entry:
switch i16 %v, label %default.dst [
``````````
</details>
https://github.com/llvm/llvm-project/pull/139736
More information about the llvm-commits
mailing list