[llvm] [CodeGen] Optimize switch on ucmp/scmp to branch sequence (PR #176582)

Kamini Banait via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 28 12:48:17 PST 2026


https://github.com/kamini08 updated https://github.com/llvm/llvm-project/pull/176582

>From 19bddd4da86b5f7efd063fd9622235f340ef85e8 Mon Sep 17 00:00:00 2001
From: kamini08 <kaminibanait03 at gmail.com>
Date: Sat, 17 Jan 2026 20:42:10 +0530
Subject: [PATCH 1/4] [CodeGenPrepare] Optimize switch on llvm.ucmp/llvm.scmp

---
 llvm/lib/CodeGen/CodeGenPrepare.cpp | 75 +++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 3185731c9350e..acb5d7ba8166c 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -9084,8 +9084,83 @@ bool CodeGenPrepare::makeBitReverse(Instruction &I) {
 // In this pass we look for GEP and cast instructions that are used
 // across basic blocks and rewrite them to improve basic-block-at-a-time
 // selection.
+// Converts switch(ucmp(x,y)) into direct branches to avoid materializing the
+// -1/0/1 value.
+static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
+                                    const TargetTransformInfo &TTI) {
+  Value *Cond = SI->getCondition();
+  auto *II = dyn_cast<IntrinsicInst>(Cond);
+  if (!II || (II->getIntrinsicID() != Intrinsic::ucmp &&
+              II->getIntrinsicID() != Intrinsic::scmp))
+    return false;
+
+  bool IsSigned = II->getIntrinsicID() == Intrinsic::scmp;
+  Value *LHS = II->getOperand(0);
+  Value *RHS = II->getOperand(1);
+
+  // 1. Map Targets (-1 -> Less, 0 -> Equal, 1 -> Greater)
+  BasicBlock *DestLess = SI->getDefaultDest();
+  BasicBlock *DestEqual = SI->getDefaultDest();
+  BasicBlock *DestGreater = SI->getDefaultDest();
+
+  for (auto Case : SI->cases()) {
+    int64_t Val = Case.getCaseValue()->getSExtValue();
+    if (Val == -1)
+      DestLess = Case.getCaseSuccessor();
+    else if (Val == 0)
+      DestEqual = Case.getCaseSuccessor();
+    else if (Val == 1)
+      DestGreater = Case.getCaseSuccessor();
+  }
+  BasicBlock *HeadBB = SI->getParent();
+  LLVMContext &Ctx = F.getContext();
+
+  // Create the intermediate block
+  BasicBlock *CheckEqBB = BasicBlock::Create(Ctx, "check.eq", &F);
+  // Insert it after HeadBB for readability
+  CheckEqBB->moveAfter(HeadBB);
+
+  // Compare Less
+  IRBuilder<> Builder(SI);
+  CmpInst::Predicate PredLess =
+      IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
+  Value *CmpLess = Builder.CreateICmp(PredLess, LHS, RHS, "cmp.less");
+
+  // Replace Switch with Branch
+  BranchInst::Create(DestLess, CheckEqBB, CmpLess, HeadBB);
+  SI->eraseFromParent();
+
+  // Compare Equal
+  Builder.SetInsertPoint(CheckEqBB);
+  Value *CmpEq = Builder.CreateICmp(ICmpInst::ICMP_EQ, LHS, RHS, "cmp.eq");
+  BranchInst::Create(DestEqual, DestGreater, CmpEq, CheckEqBB);
+
+  auto UpdatePhis = [&](BasicBlock *Dest) {
+    if (Dest == DestLess)
+      return; 
+    for (PHINode &PN : Dest->phis()) {
+      int Idx = PN.getBasicBlockIndex(HeadBB);
+      if (Idx != -1)
+        PN.setIncomingBlock(Idx, CheckEqBB);
+    }
+  };
+
+  UpdatePhis(DestEqual);
+  UpdatePhis(DestGreater);
+
+  return true;
+}
+
 bool CodeGenPrepare::optimizeBlock(BasicBlock &BB, ModifyDT &ModifiedDT) {
   SunkAddrs.clear();
+  
+  if (auto *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
+    if (optimizeSwitchOnCompare(SI, *BB.getParent(), *TTI)) {
+      ModifiedDT = ModifyDT::ModifyInstDT;
+      return true;
+    }
+  }
+
   bool MadeChange = false;
 
   do {

>From 0a20d90ebc07600148bd8f1e8411b673ec2d4fd6 Mon Sep 17 00:00:00 2001
From: kamini08 <kaminibanait03 at gmail.com>
Date: Sat, 17 Jan 2026 22:52:48 +0530
Subject: [PATCH 2/4] simplify PHI handling in switch-on-compare optimization

---
 llvm/lib/CodeGen/CodeGenPrepare.cpp | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index acb5d7ba8166c..d10ea6604b27d 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -9112,12 +9112,17 @@ static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
     else if (Val == 1)
       DestGreater = Case.getCaseSuccessor();
   }
+  
+  if (DestLess == DestEqual || DestLess == DestGreater || 
+      DestEqual == DestGreater)
+    return false;
+  
   BasicBlock *HeadBB = SI->getParent();
   LLVMContext &Ctx = F.getContext();
 
   // Create the intermediate block
   BasicBlock *CheckEqBB = BasicBlock::Create(Ctx, "check.eq", &F);
-  // Insert it after HeadBB for readability
+
   CheckEqBB->moveAfter(HeadBB);
 
   // Compare Less
@@ -9134,20 +9139,14 @@ static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
   Builder.SetInsertPoint(CheckEqBB);
   Value *CmpEq = Builder.CreateICmp(ICmpInst::ICMP_EQ, LHS, RHS, "cmp.eq");
   BranchInst::Create(DestEqual, DestGreater, CmpEq, CheckEqBB);
-
-  auto UpdatePhis = [&](BasicBlock *Dest) {
-    if (Dest == DestLess)
-      return; 
+  
+  for (BasicBlock *Dest : {DestEqual, DestGreater}) {
     for (PHINode &PN : Dest->phis()) {
       int Idx = PN.getBasicBlockIndex(HeadBB);
       if (Idx != -1)
         PN.setIncomingBlock(Idx, CheckEqBB);
     }
-  };
-
-  UpdatePhis(DestEqual);
-  UpdatePhis(DestGreater);
-
+  }
   return true;
 }
 

>From 79da867dd35f439a016cf169de10e9663d5c6a0d Mon Sep 17 00:00:00 2001
From: kamini08 <kaminibanait03 at gmail.com>
Date: Tue, 20 Jan 2026 04:58:16 +0530
Subject: [PATCH 3/4] fix ucmp/scmp switch optimization

---
 llvm/lib/CodeGen/CodeGenPrepare.cpp           | 158 ++++++------
 .../Transforms/CodeGenPrepare/ucmp-switch.ll  | 227 ++++++++++++++++++
 2 files changed, 309 insertions(+), 76 deletions(-)
 create mode 100644 llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d10ea6604b27d..a080f78e85e24 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -444,7 +444,7 @@ class CodeGenPrepare {
   bool optimizeShuffleVectorInst(ShuffleVectorInst *SVI);
   bool optimizeSwitchType(SwitchInst *SI);
   bool optimizeSwitchPhiConstants(SwitchInst *SI);
-  bool optimizeSwitchInst(SwitchInst *SI);
+  bool optimizeSwitchInst(SwitchInst *SI, ModifyDT &ModifiedDT);
   bool optimizeExtractElementInst(Instruction *Inst);
   bool dupRetToEnableTailCallOpts(BasicBlock *BB, ModifyDT &ModifiedDT);
   bool fixupDbgVariableRecord(DbgVariableRecord &I);
@@ -8125,9 +8125,88 @@ bool CodeGenPrepare::optimizeSwitchPhiConstants(SwitchInst *SI) {
   return Changed;
 }
 
-bool CodeGenPrepare::optimizeSwitchInst(SwitchInst *SI) {
+// Converts switch(ucmp(x,y)) or switch(scmp(x,y)) into a sequence of two
+// conditional branches to avoid materializing the -1/0/1 value.
+static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
+                                    const TargetTransformInfo &TTI,
+                                    ModifyDT &ModifiedDT) {
+  Value *Cond = SI->getCondition();
+
+  if (auto *Cast = dyn_cast<CastInst>(Cond)) {
+    if (Cast->getOpcode() == Instruction::ZExt ||
+        Cast->getOpcode() == Instruction::SExt) {
+      Cond = Cast->getOperand(0);
+    }
+  }
+
+  auto *II = dyn_cast<CmpIntrinsic>(Cond);
+  if (!II)
+    return false;
+
+  Value *LHS = II->getOperand(0);
+  Value *RHS = II->getOperand(1);
+
+  // Map targets (-1 -> Less, 0 -> Equal, 1 -> Greater); default otherwise.
+  BasicBlock *DestLess = SI->getDefaultDest();
+  BasicBlock *DestEqual = SI->getDefaultDest();
+  BasicBlock *DestGreater = SI->getDefaultDest();
+
+  unsigned IntrinsicWidth = II->getType()->getScalarSizeInBits();
+
+  for (auto Case : SI->cases()) {
+    APInt Val = Case.getCaseValue()->getValue();
+
+    if (Val.getBitWidth() > IntrinsicWidth)
+      Val = Val.trunc(IntrinsicWidth);
+
+    if (Val.isAllOnes())
+      DestLess = Case.getCaseSuccessor();
+    else if (Val.isZero())
+      DestEqual = Case.getCaseSuccessor();
+    else if (Val.isOne())
+      DestGreater = Case.getCaseSuccessor();
+  }
+
+  // Cases with common destinations will be simplified by
+  // simplifySwitchOfCmpIntrinsic
+  if (DestLess == DestEqual || DestLess == DestGreater ||
+      DestEqual == DestGreater)
+    return false;
+
+  BasicBlock *HeadBB = SI->getParent();
+  LLVMContext &Ctx = F.getContext();
+
+  // Create the intermediate block
+  BasicBlock *CheckEqBB = BasicBlock::Create(Ctx, "check.eq", &F);
+
+  CheckEqBB->moveAfter(HeadBB);
+
+  // Compare Less
+  IRBuilder<> Builder(SI);
+  CmpInst::Predicate PredLess = II->getLTPredicate();
+  Value *CmpLess = Builder.CreateICmp(PredLess, LHS, RHS, "cmp.less");
+
+  // Replace Switch with Branch
+  BranchInst::Create(DestLess, CheckEqBB, CmpLess, HeadBB);
+  SI->eraseFromParent();
+
+  // Compare Equal
+  Builder.SetInsertPoint(CheckEqBB);
+  Value *CmpEq = Builder.CreateICmp(ICmpInst::ICMP_EQ, LHS, RHS, "cmp.eq");
+  BranchInst::Create(DestEqual, DestGreater, CmpEq, CheckEqBB);
+
+  for (BasicBlock *Dest : {DestEqual, DestGreater})
+    Dest->replacePhiUsesWith(HeadBB, CheckEqBB);
+
+  ModifiedDT = ModifyDT::ModifyInstDT;
+  return true;
+}
+
+bool CodeGenPrepare::optimizeSwitchInst(SwitchInst *SI, ModifyDT &ModifiedDT) {
   bool Changed = optimizeSwitchType(SI);
   Changed |= optimizeSwitchPhiConstants(SI);
+  if (optimizeSwitchOnCompare(SI, *SI->getFunction(), *TTI, ModifiedDT))
+    return true;
   return Changed;
 }
 
@@ -9052,7 +9131,7 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
   case Instruction::ShuffleVector:
     return optimizeShuffleVectorInst(cast<ShuffleVectorInst>(I));
   case Instruction::Switch:
-    return optimizeSwitchInst(cast<SwitchInst>(I));
+    return optimizeSwitchInst(cast<SwitchInst>(I), ModifiedDT);
   case Instruction::ExtractElement:
     return optimizeExtractElementInst(cast<ExtractElementInst>(I));
   case Instruction::Br:
@@ -9084,81 +9163,8 @@ bool CodeGenPrepare::makeBitReverse(Instruction &I) {
 // In this pass we look for GEP and cast instructions that are used
 // across basic blocks and rewrite them to improve basic-block-at-a-time
 // selection.
-// Converts switch(ucmp(x,y)) into direct branches to avoid materializing the
-// -1/0/1 value.
-static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
-                                    const TargetTransformInfo &TTI) {
-  Value *Cond = SI->getCondition();
-  auto *II = dyn_cast<IntrinsicInst>(Cond);
-  if (!II || (II->getIntrinsicID() != Intrinsic::ucmp &&
-              II->getIntrinsicID() != Intrinsic::scmp))
-    return false;
-
-  bool IsSigned = II->getIntrinsicID() == Intrinsic::scmp;
-  Value *LHS = II->getOperand(0);
-  Value *RHS = II->getOperand(1);
-
-  // 1. Map Targets (-1 -> Less, 0 -> Equal, 1 -> Greater)
-  BasicBlock *DestLess = SI->getDefaultDest();
-  BasicBlock *DestEqual = SI->getDefaultDest();
-  BasicBlock *DestGreater = SI->getDefaultDest();
-
-  for (auto Case : SI->cases()) {
-    int64_t Val = Case.getCaseValue()->getSExtValue();
-    if (Val == -1)
-      DestLess = Case.getCaseSuccessor();
-    else if (Val == 0)
-      DestEqual = Case.getCaseSuccessor();
-    else if (Val == 1)
-      DestGreater = Case.getCaseSuccessor();
-  }
-  
-  if (DestLess == DestEqual || DestLess == DestGreater || 
-      DestEqual == DestGreater)
-    return false;
-  
-  BasicBlock *HeadBB = SI->getParent();
-  LLVMContext &Ctx = F.getContext();
-
-  // Create the intermediate block
-  BasicBlock *CheckEqBB = BasicBlock::Create(Ctx, "check.eq", &F);
-
-  CheckEqBB->moveAfter(HeadBB);
-
-  // Compare Less
-  IRBuilder<> Builder(SI);
-  CmpInst::Predicate PredLess =
-      IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
-  Value *CmpLess = Builder.CreateICmp(PredLess, LHS, RHS, "cmp.less");
-
-  // Replace Switch with Branch
-  BranchInst::Create(DestLess, CheckEqBB, CmpLess, HeadBB);
-  SI->eraseFromParent();
-
-  // Compare Equal
-  Builder.SetInsertPoint(CheckEqBB);
-  Value *CmpEq = Builder.CreateICmp(ICmpInst::ICMP_EQ, LHS, RHS, "cmp.eq");
-  BranchInst::Create(DestEqual, DestGreater, CmpEq, CheckEqBB);
-  
-  for (BasicBlock *Dest : {DestEqual, DestGreater}) {
-    for (PHINode &PN : Dest->phis()) {
-      int Idx = PN.getBasicBlockIndex(HeadBB);
-      if (Idx != -1)
-        PN.setIncomingBlock(Idx, CheckEqBB);
-    }
-  }
-  return true;
-}
-
 bool CodeGenPrepare::optimizeBlock(BasicBlock &BB, ModifyDT &ModifiedDT) {
   SunkAddrs.clear();
-  
-  if (auto *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
-    if (optimizeSwitchOnCompare(SI, *BB.getParent(), *TTI)) {
-      ModifiedDT = ModifyDT::ModifyInstDT;
-      return true;
-    }
-  }
 
   bool MadeChange = false;
 
diff --git a/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll b/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll
new file mode 100644
index 0000000000000..96aa3abdbe0fc
--- /dev/null
+++ b/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll
@@ -0,0 +1,227 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -codegenprepare -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; --- DECLARATIONS ---
+declare void @func_less()
+declare void @func_equal()
+declare void @func_greater()
+declare i8 @llvm.ucmp.i8.i8(i8, i8)
+declare i8 @llvm.scmp.i8.i8(i8, i8)
+declare i8 @llvm.ucmp.i8.i47(i47, i47)
+
+; --- BASIC TESTS ---
+
+define void @test_ucmp_switch(i8 noundef %x, i8 noundef %y) {
+; CHECK-LABEL: @test_ucmp_switch(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp ult i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+; CHECK:       less:
+; CHECK-NEXT:    tail call void @func_less()
+; CHECK-NEXT:    ret void
+; CHECK:       equal:
+; CHECK-NEXT:    tail call void @func_equal()
+; CHECK-NEXT:    ret void
+; CHECK:       greater:
+; CHECK-NEXT:    tail call void @func_greater()
+; CHECK-NEXT:    ret void
+;
+start:
+  %val = call i8 @llvm.ucmp.i8.i8(i8 %x, i8 %y)
+  %ext = zext i8 %val to i32
+  switch i32 %ext, label %unreachable_block [
+    i32 255, label %less
+    i32 0, label %equal
+    i32 1, label %greater
+  ]
+
+unreachable_block:
+  unreachable
+
+less:
+  tail call void @func_less()
+  ret void
+
+equal:
+  tail call void @func_equal()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+}
+
+define void @test_scmp_switch(i8 noundef %x, i8 noundef %y) {
+; CHECK-LABEL: @test_scmp_switch(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp slt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+;
+start:
+  %val = call i8 @llvm.scmp.i8.i8(i8 %x, i8 %y)
+  %ext = zext i8 %val to i32
+  switch i32 %ext, label %unreachable_block [
+    i32 255, label %less
+    i32 0, label %equal
+    i32 1, label %greater
+  ]
+
+unreachable_block:
+  unreachable
+
+less:
+  tail call void @func_less()
+  ret void
+
+equal:
+  tail call void @func_equal()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+}
+
+; --- STRESS TESTS (Corner Cases) ---
+
+define void @stress_missing_case_default(i8 %x, i8 %y) {
+; CHECK-LABEL: @stress_missing_case_default(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp ult i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[DEFAULT:%.*]], label [[GREATER:%.*]]
+;
+start:
+  %val = call i8 @llvm.ucmp.i8.i8(i8 %x, i8 %y)
+  %ext = zext i8 %val to i32
+  switch i32 %ext, label %default [
+    i32 255, label %less    ; -1 goes to Less
+    i32 1, label %greater   ; 1 goes to Greater
+    ; 0 is MISSING! It should go to %default.
+  ]
+
+less:
+  tail call void @func_less()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+
+default: ; acts as Equal
+  tail call void @func_equal()
+  ret void
+}
+
+define void @stress_weird_types(i47 %x, i47 %y) {
+; CHECK-LABEL: @stress_weird_types(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp ult i47 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i47 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+;
+start:
+  %val = call i8 @llvm.ucmp.i8.i47(i47 %x, i47 %y)
+  %ext = zext i8 %val to i32
+  switch i32 %ext, label %unreachable_block [
+    i32 255, label %less
+    i32 0, label %equal
+    i32 1, label %greater
+  ]
+
+unreachable_block:
+  unreachable
+
+less:
+  tail call void @func_less()
+  ret void
+
+equal:
+  tail call void @func_equal()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+}
+
+define void @stress_sext_scmp(i8 %x, i8 %y) {
+; CHECK-LABEL: @stress_sext_scmp(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp slt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+;
+start:
+  %val = call i8 @llvm.scmp.i8.i8(i8 %x, i8 %y)
+  %ext = sext i8 %val to i32  ; <--- Testing SEXT here
+  switch i32 %ext, label %unreachable_block [
+    i32 -1, label %less       ; Case value is -1 (0xFFFFFFFF)
+    i32 0, label %equal
+    i32 1, label %greater
+  ]
+
+unreachable_block:
+  unreachable
+
+less:
+  tail call void @func_less()
+  ret void
+
+equal:
+  tail call void @func_equal()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+}
+
+define void @stress_no_cast(i8 %x, i8 %y) {
+; CHECK-LABEL: @stress_no_cast(
+; CHECK-NEXT:  start:
+; CHECK:         [[CMP_LESS:%.*]] = icmp ult i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+;
+start:
+  %val = call i8 @llvm.ucmp.i8.i8(i8 %x, i8 %y)
+  ; NO CAST HERE! Switch directly on the intrinsic result.
+  switch i8 %val, label %unreachable_block [
+    i8 255, label %less
+    i8 0, label %equal
+    i8 1, label %greater
+  ]
+
+unreachable_block:
+  unreachable
+
+less:
+  tail call void @func_less()
+  ret void
+
+equal:
+  tail call void @func_equal()
+  ret void
+
+greater:
+  tail call void @func_greater()
+  ret void
+}
\ No newline at end of file

>From 7528ff55d71050c81d6bd3caf9912ee3671bab2d Mon Sep 17 00:00:00 2001
From: kamini08 <kaminibanait03 at gmail.com>
Date: Thu, 29 Jan 2026 01:24:38 +0530
Subject: [PATCH 4/4] [CodeGen] fix bugs in optimizeSwitchOnCompare
 transformation

---
 llvm/lib/CodeGen/CodeGenPrepare.cpp           | 19 +++++-
 .../Transforms/CodeGenPrepare/ucmp-switch.ll  | 63 +++++++++++++++++++
 2 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index a080f78e85e24..0a80bc85afbf6 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8131,11 +8131,14 @@ static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
                                     const TargetTransformInfo &TTI,
                                     ModifyDT &ModifiedDT) {
   Value *Cond = SI->getCondition();
+  Instruction::CastOps CastOp = Instruction::CastOps::CastOpsEnd;
 
   if (auto *Cast = dyn_cast<CastInst>(Cond)) {
-    if (Cast->getOpcode() == Instruction::ZExt ||
-        Cast->getOpcode() == Instruction::SExt) {
+    CastOp = Cast->getOpcode();
+    if (CastOp == Instruction::ZExt || CastOp == Instruction::SExt) {
       Cond = Cast->getOperand(0);
+    } else {
+      CastOp = Instruction::CastOps::CastOpsEnd;
     }
   }
 
@@ -8156,6 +8159,16 @@ static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
   for (auto Case : SI->cases()) {
     APInt Val = Case.getCaseValue()->getValue();
 
+    if (CastOp == Instruction::ZExt) {
+      if (Val.getActiveBits() > IntrinsicWidth)
+        continue; // Unreachable case
+    }
+    else if (CastOp == Instruction::SExt) {
+      unsigned MinSignedBits = Val.getBitWidth() - Val.getNumSignBits() + 1;
+      if (MinSignedBits > IntrinsicWidth)
+        continue; // Unreachable case
+    }
+
     if (Val.getBitWidth() > IntrinsicWidth)
       Val = Val.trunc(IntrinsicWidth);
 
@@ -8198,7 +8211,7 @@ static bool optimizeSwitchOnCompare(SwitchInst *SI, Function &F,
   for (BasicBlock *Dest : {DestEqual, DestGreater})
     Dest->replacePhiUsesWith(HeadBB, CheckEqBB);
 
-  ModifiedDT = ModifyDT::ModifyInstDT;
+  ModifiedDT = ModifyDT::ModifyBBDT;
   return true;
 }
 
diff --git a/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll b/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll
index 96aa3abdbe0fc..5e2e2b0bbda52 100644
--- a/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/ucmp-switch.ll
@@ -224,4 +224,67 @@ equal:
 greater:
   tail call void @func_greater()
   ret void
+}
+
+; --- PHI NODE HANDLING TEST ---
+
+define i32 @test_phi_nodes(i8 %x, i8 %y, i32 %a, i32 %b, i1 %cond) {
+; CHECK-LABEL: @test_phi_nodes(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[PRE_SWITCH:%.*]], label [[FALLTHROUGH:%.*]]
+;
+; CHECK:       pre.switch:
+; CHECK-NEXT:    [[VAL:%.*]] = call i8 @llvm.ucmp.i8.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[EXT:%.*]] = zext i8 [[VAL]] to i32
+; CHECK-NEXT:    [[CMP_LESS:%.*]] = icmp ult i8 [[X]], [[Y]]
+;                ;; Verify 'less' branch still comes from pre.switch
+; CHECK-NEXT:    br i1 [[CMP_LESS]], label [[LESS:%.*]], label [[CHECK_EQ:%.*]]
+;
+; CHECK:       check.eq:
+; CHECK-NEXT:    [[CMP_EQ:%.*]] = icmp eq i8 [[X]], [[Y]]
+;                ;; Verify equal/greater come from check.eq
+; CHECK-NEXT:    br i1 [[CMP_EQ]], label [[EQUAL:%.*]], label [[GREATER:%.*]]
+;
+; CHECK:       less:
+;                ;; CRITICAL: Verify predecessor is still [[PRE_SWITCH]]
+; CHECK-NEXT:    [[PHI_LESS:%.*]] = phi i32 [ [[A:%.*]], [[PRE_SWITCH]] ], [ [[B:%.*]], %entry ]
+; CHECK-NEXT:    ret i32 [[PHI_LESS]]
+;
+; CHECK:       equal:
+;                ;; No phi is expected here: the single-entry phi folds to [[A]].
+; CHECK-NEXT:    ret i32 [[A]]
+;
+; CHECK:       greater:
+;                ;; No phi is expected here: the single-entry phi folds to [[B]].
+; CHECK-NEXT:    ret i32 [[B]]
+;
+entry:
+  br i1 %cond, label %pre.switch, label %fallthrough
+
+pre.switch:
+  %val = call i8 @llvm.ucmp.i8.i8(i8 %x, i8 %y)
+  %ext = zext i8 %val to i32
+  switch i32 %ext, label %unreachable_block [
+    i32 255, label %less
+    i32 0, label %equal
+    i32 1, label %greater
+  ]
+
+fallthrough:
+  br label %less
+
+unreachable_block:
+  unreachable
+
+less:
+  %phi.less = phi i32 [ %a, %pre.switch ], [ %b, %fallthrough ]
+  ret i32 %phi.less
+
+equal:
+  %phi.equal = phi i32 [ %a, %pre.switch ]
+  ret i32 %phi.equal
+
+greater:
+  %phi.greater = phi i32 [ %b, %pre.switch ]
+  ret i32 %phi.greater
 }
\ No newline at end of file



More information about the llvm-commits mailing list