[llvm] [SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (PR #105636)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 06:42:58 PDT 2024


================
@@ -7131,6 +7131,117 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
   return true;
 }
 
+/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
+/// the same destination.
+static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
+                                         DomTreeUpdater *DTU) {
+  auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
+  if (!Cmp || !Cmp->hasOneUse())
+    return false;
+
+  SmallVector<uint32_t, 4> Weights;
+  bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights);
+  if (!HasWeights)
+    Weights.resize(4); // Avoid checking HasWeights everywhere.
+
+  // Normalize to [us]cmp == Res ? Succ : OtherSucc.
+  int64_t Res;
+  BasicBlock *Succ, *OtherSucc;
+  uint32_t SuccWeight = 0, OtherSuccWeight = 0;
+  BasicBlock *Unreachable = nullptr;
+
+  if (SI->getNumCases() == 2) {
+    // Find which of 1, 0 or -1 is missing (handled by default dest).
+    SmallSet<int64_t, 3> Missing;
+    Missing.insert(1);
+    Missing.insert(0);
+    Missing.insert(-1);
+
+    Succ = SI->getDefaultDest();
+    SuccWeight = Weights[0];
+    OtherSucc = nullptr;
+    for (auto &Case : SI->cases()) {
+      std::optional<int64_t> Val =
+          Case.getCaseValue()->getValue().trySExtValue();
+      if (!Val)
+        return false;
+      if (!Missing.erase(*Val))
+        return false;
+      if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
+        return false;
+      OtherSucc = Case.getCaseSuccessor();
+      OtherSuccWeight += Weights[Case.getSuccessorIndex()];
+    }
+
+    assert(Missing.size() == 1 && "Should have one case left");
+    Res = *Missing.begin();
+  } else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) {
+    // Normalize so that Succ is taken once and OtherSucc twice.
+    Unreachable = SI->getDefaultDest();
+    Succ = OtherSucc = nullptr;
+    for (auto &Case : SI->cases()) {
+      BasicBlock *NewSucc = Case.getCaseSuccessor();
+      uint32_t Weight = Weights[Case.getSuccessorIndex()];
+      if (!OtherSucc || OtherSucc == NewSucc) {
+        OtherSucc = NewSucc;
+        OtherSuccWeight += Weight;
+      } else if (!Succ) {
+        Succ = NewSucc;
+        SuccWeight = Weight;
+      } else if (Succ == NewSucc) {
+        std::swap(Succ, OtherSucc);
+        std::swap(SuccWeight, OtherSuccWeight);
+      } else
+        return false;
+    }
+    for (auto &Case : SI->cases()) {
+      std::optional<int64_t> Val =
+          Case.getCaseValue()->getValue().trySExtValue();
+      if (!Val || (Val != 1 && Val != 0 && Val != -1))
+        return false;
+      if (Case.getCaseSuccessor() == Succ)
+        Res = *Val;
----------------
dtcxzyw wrote:

```suggestion
      if (Case.getCaseSuccessor() == Succ) {
        Res = *Val;
        break;
      }
```

https://github.com/llvm/llvm-project/pull/105636


More information about the llvm-commits mailing list