[llvm] [SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (PR #105636)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 03:09:03 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/105636

If we switch over the result of ucmp/scmp and two of the switch cases go to the same destination, we can convert the switch into icmp+br.

Fixes https://github.com/llvm/llvm-project/issues/105632.

>From 1620a6ee2ac01e9fe989faf645c13fa1a2b606d0 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 22 Aug 2024 11:49:46 +0200
Subject: [PATCH] [SimplifyCFG] Fold switch over ucmp/scmp to icmp and br

If we switch over ucmp/scmp and have two switch cases going to
the same destination, we can convert the switch into icmp+br.

Fixes https://github.com/llvm/llvm-project/issues/105632.
---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 56 ++++++++++++++++
 .../Transforms/SimplifyCFG/switch-on-cmp.ll   | 65 ++++++-------------
 2 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 00efd3c0eb72ec..d7295d89078b6c 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7131,6 +7131,59 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
   return true;
 }
 
+/// Fold a two-case switch over a single-use ucmp/scmp intrinsic into icmp+br
+/// when both cases share one successor. Returns true if the IR was changed.
+static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI,
+                                         IRBuilderBase &Builder) {
+  auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
+  if (!Cmp || !Cmp->hasOneUse() || SI->getNumCases() != 2)
+    return false;
+
+  // Find which of 1, 0 or -1 has no case and so goes to the default dest.
+  SmallSet<int64_t, 3> Missing;
+  Missing.insert(1);
+  Missing.insert(0);
+  Missing.insert(-1);
+  BasicBlock *Succ = nullptr;
+  for (auto &Case : SI->cases()) {
+    std::optional<int64_t> Val = Case.getCaseValue()->getValue().trySExtValue();
+    if (!Val)
+      return false;
+    if (!Missing.erase(*Val)) // Case value is not one of 1, 0 or -1.
+      return false;
+    if (Succ && Succ != Case.getCaseSuccessor()) // Cases must share a target.
+      return false;
+    Succ = Case.getCaseSuccessor();
+  }
+
+  // Pick the predicate that holds exactly for the missing (default) value.
+  ICmpInst::Predicate Pred;
+  assert(Missing.size() == 1 && "Should have one case left");
+  switch (*Missing.begin()) {
+  case 1:
+    Pred = ICmpInst::ICMP_UGT;
+    break;
+  case 0:
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case -1:
+    Pred = ICmpInst::ICMP_ULT;
+    break;
+  }
+  if (Cmp->isSigned())
+    Pred = ICmpInst::getSignedPredicate(Pred);
+
+  // No dominator tree update is needed: the two edges to Succ collapse into
+  // one, and the dominator tree treats multi-edges like a single edge anyway.
+  Builder.SetInsertPoint(SI->getIterator());
+  Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
+  Builder.CreateCondBr(ICmp, SI->getDefaultDest(), Succ); // True -> default.
+  Succ->removePredecessor(SI->getParent()); // Two edges to Succ became one.
+  SI->eraseFromParent();
+  Cmp->eraseFromParent();
+  return true;
+}
+
 bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   BasicBlock *BB = SI->getParent();
 
@@ -7163,6 +7216,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
     return requestResimplify();
 
+  if (simplifySwitchOfCmpIntrinsic(SI, Builder))
+    return requestResimplify();
+
   if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
     return requestResimplify();
 
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
index 1ce18533d156d0..a9b7a0b9e2f096 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
@@ -4,11 +4,8 @@
 define void @ucmp_gt1(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_gt1(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 0, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -32,11 +29,8 @@ bb2:
 define void @ucmp_gt2(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_gt2(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 0, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 -1, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -60,11 +54,8 @@ bb2:
 define void @ucmp_lt1(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_lt1(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB2:.*]] [
-; CHECK-NEXT:      i8 1, label %[[BB1:.*]]
-; CHECK-NEXT:      i8 0, label %[[BB1]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB2:.*]], label %[[BB1:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -88,11 +79,8 @@ bb2:
 define void @ucmp_lt2(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_lt2(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB2:.*]] [
-; CHECK-NEXT:      i8 0, label %[[BB1:.*]]
-; CHECK-NEXT:      i8 1, label %[[BB1]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB2:.*]], label %[[BB1:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -116,11 +104,8 @@ bb2:
 define void @ucmp_eq1(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_eq1(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 1, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -144,11 +129,8 @@ bb2:
 define void @ucmp_eq2(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @ucmp_eq2(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 1, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 -1, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -172,11 +154,8 @@ bb2:
 define void @scmp_gt1(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @scmp_gt1(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 0, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -200,11 +179,8 @@ bb2:
 define void @scmp_gt2(i32 %a, i32 %b) {
 ; CHECK-LABEL: define void @scmp_gt2(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 0, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 -1, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
@@ -259,16 +235,13 @@ define i32 @ucmp_gt_phi(i32 %a, i32 %b) {
 ; CHECK-LABEL: define i32 @ucmp_gt_phi(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*]]:
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT:      i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT:      i8 0, label %[[BB2]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[TMP0]], label %[[BB1:.*]], label %[[BB2:.*]]
 ; CHECK:       [[BB1]]:
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label %[[BB2]]
 ; CHECK:       [[BB2]]:
-; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ 0, %[[BB1]] ], [ 1, %[[ENTRY]] ], [ 1, %[[ENTRY]] ]
+; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ 0, %[[BB1]] ], [ 1, %[[ENTRY]] ]
 ; CHECK-NEXT:    ret i32 [[PHI]]
 ;
 entry:



More information about the llvm-commits mailing list