[llvm] 535d8e8 - NFC: Extract switch lowering binary tree splitting code from DAG into SwitchLoweringUtils.

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 7 07:42:34 PST 2024


Author: Amara Emerson
Date: 2024-01-07T07:42:27-08:00
New Revision: 535d8e8b92e3f8cf4107d9431012310c9a72c8d3

URL: https://github.com/llvm/llvm-project/commit/535d8e8b92e3f8cf4107d9431012310c9a72c8d3
DIFF: https://github.com/llvm/llvm-project/commit/535d8e8b92e3f8cf4107d9431012310c9a72c8d3.diff

LOG: NFC: Extract switch lowering binary tree splitting code from DAG into SwitchLoweringUtils.

This will help re-use this code with the upcoming GlobalISel implementation of
this optimization.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
    llvm/lib/CodeGen/SwitchLoweringUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
index 5d06e21737b88c..99478e9f39e226 100644
--- a/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
+++ b/llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
@@ -293,6 +293,22 @@ class SwitchLowering {
       MachineBasicBlock *Src, MachineBasicBlock *Dst,
       BranchProbability Prob = BranchProbability::getUnknown()) = 0;
 
+  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
+  /// than each cluster in the range, its rank is 0.
+  unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
+                           CaseClusterIt Last);
+
+  struct SplitWorkItemInfo {
+    CaseClusterIt LastLeft;
+    CaseClusterIt FirstRight;
+    BranchProbability LeftProb;
+    BranchProbability RightProb;
+  };
+  /// Compute information to balance the tree based on branch probabilities to
+  /// create a near-optimal (in terms of search time given key frequency) binary
+  /// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
+  /// (1975).
+  SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
   virtual ~SwitchLowering() = default;
 
 private:

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 78ebd2d33459a7..1ae682eaf2511c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -11639,92 +11639,16 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
   }
 }
 
-unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC,
-                                              CaseClusterIt First,
-                                              CaseClusterIt Last) {
-  return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
-    if (X.Prob != CC.Prob)
-      return X.Prob > CC.Prob;
-
-    // Ties are broken by comparing the case value.
-    return X.Low->getValue().slt(CC.Low->getValue());
-  });
-}
-
 void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
                                         const SwitchWorkListItem &W,
                                         Value *Cond,
                                         MachineBasicBlock *SwitchMBB) {
   assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) &&
          "Clusters not sorted?");
-
   assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!");
 
-  // Balance the tree based on branch probabilities to create a near-optimal (in
-  // terms of search time given key frequency) binary search tree. See e.g. Kurt
-  // Mehlhorn "Nearly Optimal Binary Search Trees" (1975).
-  CaseClusterIt LastLeft = W.FirstCluster;
-  CaseClusterIt FirstRight = W.LastCluster;
-  auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
-  auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
-
-  // Move LastLeft and FirstRight towards each other from opposite directions to
-  // find a partitioning of the clusters which balances the probability on both
-  // sides. If LeftProb and RightProb are equal, alternate which side is
-  // taken to ensure 0-probability nodes are distributed evenly.
-  unsigned I = 0;
-  while (LastLeft + 1 < FirstRight) {
-    if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
-      LeftProb += (++LastLeft)->Prob;
-    else
-      RightProb += (--FirstRight)->Prob;
-    I++;
-  }
-
-  while (true) {
-    // Our binary search tree differs from a typical BST in that ours can have up
-    // to three values in each leaf. The pivot selection above doesn't take that
-    // into account, which means the tree might require more nodes and be less
-    // efficient. We compensate for this here.
-
-    unsigned NumLeft = LastLeft - W.FirstCluster + 1;
-    unsigned NumRight = W.LastCluster - FirstRight + 1;
-
-    if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
-      // If one side has less than 3 clusters, and the other has more than 3,
-      // consider taking a cluster from the other side.
-
-      if (NumLeft < NumRight) {
-        // Consider moving the first cluster on the right to the left side.
-        CaseCluster &CC = *FirstRight;
-        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
-        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
-        if (LeftSideRank <= RightSideRank) {
-          // Moving the cluster to the left does not demote it.
-          ++LastLeft;
-          ++FirstRight;
-          continue;
-        }
-      } else {
-        assert(NumRight < NumLeft);
-        // Consider moving the last element on the left to the right side.
-        CaseCluster &CC = *LastLeft;
-        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
-        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
-        if (RightSideRank <= LeftSideRank) {
-          // Moving the cluster to the right does not demot it.
-          --LastLeft;
-          --FirstRight;
-          continue;
-        }
-      }
-    }
-    break;
-  }
-
-  assert(LastLeft + 1 == FirstRight);
-  assert(LastLeft >= W.FirstCluster);
-  assert(FirstRight <= W.LastCluster);
+  auto [LastLeft, FirstRight, LeftProb, RightProb] =
+      SL->computeSplitWorkItemInfo(W);
 
   // Use the first element on the right as pivot since we will make less-than
   // comparisons against it.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 2e102c002c093e..6dcb8c816ad080 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -200,12 +200,6 @@ class SelectionDAGBuilder {
   /// create.
   unsigned SDNodeOrder;
 
-  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
-  /// than each cluster in the range, its rank is 0.
-  unsigned caseClusterRank(const SwitchCG::CaseCluster &CC,
-                           SwitchCG::CaseClusterIt First,
-                           SwitchCG::CaseClusterIt Last);
-
   /// Emit comparison and split W into two subtrees.
   void splitWorkItem(SwitchCG::SwitchWorkList &WorkList,
                      const SwitchCG::SwitchWorkListItem &W, Value *Cond,

diff  --git a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
index 7982d80353bd40..8922fa5898133a 100644
--- a/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
+++ b/llvm/lib/CodeGen/SwitchLoweringUtils.cpp
@@ -494,3 +494,84 @@ void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
   }
   Clusters.resize(DstIndex);
 }
+
+unsigned SwitchCG::SwitchLowering::caseClusterRank(const CaseCluster &CC,
+                                                   CaseClusterIt First,
+                                                   CaseClusterIt Last) {
+  return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
+    if (X.Prob != CC.Prob)
+      return X.Prob > CC.Prob;
+
+    // Ties are broken by comparing the case value.
+    return X.Low->getValue().slt(CC.Low->getValue());
+  });
+}
+
+llvm::SwitchCG::SwitchLowering::SplitWorkItemInfo
+SwitchCG::SwitchLowering::computeSplitWorkItemInfo(
+    const SwitchWorkListItem &W) {
+  CaseClusterIt LastLeft = W.FirstCluster;
+  CaseClusterIt FirstRight = W.LastCluster;
+  auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
+  auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
+
+  // Move LastLeft and FirstRight towards each other from opposite directions to
+  // find a partitioning of the clusters which balances the probability on both
+  // sides. If LeftProb and RightProb are equal, alternate which side is
+  // taken to ensure 0-probability nodes are distributed evenly.
+  unsigned I = 0;
+  while (LastLeft + 1 < FirstRight) {
+    if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
+      LeftProb += (++LastLeft)->Prob;
+    else
+      RightProb += (--FirstRight)->Prob;
+    I++;
+  }
+
+  while (true) {
+    // Our binary search tree differs from a typical BST in that ours can have
+    // up to three values in each leaf. The pivot selection above doesn't take
+    // that into account, which means the tree might require more nodes and be
+    // less efficient. We compensate for this here.
+
+    unsigned NumLeft = LastLeft - W.FirstCluster + 1;
+    unsigned NumRight = W.LastCluster - FirstRight + 1;
+
+    if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
+      // If one side has less than 3 clusters, and the other has more than 3,
+      // consider taking a cluster from the other side.
+
+      if (NumLeft < NumRight) {
+        // Consider moving the first cluster on the right to the left side.
+        CaseCluster &CC = *FirstRight;
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        if (LeftSideRank <= RightSideRank) {
+          // Moving the cluster to the left does not demote it.
+          ++LastLeft;
+          ++FirstRight;
+          continue;
+        }
+      } else {
+        assert(NumRight < NumLeft);
+        // Consider moving the last element on the left to the right side.
+        CaseCluster &CC = *LastLeft;
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        if (RightSideRank <= LeftSideRank) {
+          // Moving the cluster to the right does not demot it.
+          --LastLeft;
+          --FirstRight;
+          continue;
+        }
+      }
+    }
+    break;
+  }
+
+  assert(LastLeft + 1 == FirstRight);
+  assert(LastLeft >= W.FirstCluster);
+  assert(FirstRight <= W.LastCluster);
+
+  return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};
+}
\ No newline at end of file


        


More information about the llvm-commits mailing list