[llvm] [WIP][Inline] Rewrite switch's handling of the default branch (PR #85160)

Quentin Dian via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 03:37:00 PDT 2024


https://github.com/DianQK updated https://github.com/llvm/llvm-project/pull/85160

>From b74b7b125c3243a2113b57fa1b39017918b804c8 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Thu, 22 Feb 2024 20:31:44 +0800
Subject: [PATCH 1/3] [Inline] Consider the default branch when transforming to
 comparison

---
 llvm/lib/Analysis/InlineCost.cpp | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e947..e26fcfe8876f22 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -701,24 +701,29 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
     if (JumpTableSize) {
+      // Suppose a default branch includes one compare and one conditional
+      // branch if it's reachable.
+      if (!DefaultDestUndefined)
+        addCost(2 * InstrCost);
       int64_t JTCost =
           static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
       addCost(JTCost);
       return;
     }
 
-    if (NumCaseCluster <= 3) {
+    if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
       // Suppose a comparison includes one compare and one conditional branch.
-      addCost(NumCaseCluster * 2 * InstrCost);
+      // We can create one less set of instructions if the default branch is
+      // undefined.
+      addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
       return;
     }
 
+    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
         getExpectedNumberOfCompare(NumCaseCluster);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
@@ -1235,23 +1240,23 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-                SwitchDefaultDestCostMultiplier * InstrCost);
-
     if (JumpTableSize) {
+      if (!DefaultDestUndefined)
+        increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+            SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
       return;
     }
 
-    if (NumCaseCluster <= 3) {
+    if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - !DefaultDestUndefined) * CaseClusterCostMultiplier * InstrCost);
       return;
     }
 
+    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
         getExpectedNumberOfCompare(NumCaseCluster);
 

>From 0a525581096c889f5122f4008d2fe8eef9af071e Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Wed, 13 Mar 2024 22:04:22 +0800
Subject: [PATCH 2/3] [Inline] The jump table only requires a jump instruction

---
 llvm/lib/Analysis/InlineCost.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e26fcfe8876f22..4b505d3a832cc5 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -709,8 +709,9 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       // branch if it's reachable.
       if (!DefaultDestUndefined)
         addCost(2 * InstrCost);
+      // The jump table only requires a jump instruction.
       int64_t JTCost =
-          static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+          static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
       addCost(JTCost);
       return;
     }
@@ -1157,7 +1158,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   // FIXME: These constants are taken from the heuristic-based cost visitor.
   // These should be removed entirely in a later revision to avoid reliance on
   // heuristics in the ML inliner.
-  static constexpr int JTCostMultiplier = 4;
+  static constexpr int JTCostMultiplier = 1;
   static constexpr int CaseClusterCostMultiplier = 2;
   static constexpr int SwitchDefaultDestCostMultiplier = 2;
   static constexpr int SwitchCostMultiplier = 2;
@@ -1243,7 +1244,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
     if (JumpTableSize) {
       if (!DefaultDestUndefined)
         increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-            SwitchDefaultDestCostMultiplier * InstrCost);
+                  SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1252,7 +1253,8 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
     if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                (NumCaseCluster - !DefaultDestUndefined) * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - DefaultDestUndefined) *
+                    CaseClusterCostMultiplier * InstrCost);
       return;
     }
 

>From 2c29887197494a4b6364456d667e03f81973284e Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Thu, 14 Mar 2024 18:35:42 +0800
Subject: [PATCH 3/3] [Inline] Estimate the number of comparisons for the
 unreachable default branch

---
 llvm/lib/Analysis/InlineCost.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 4b505d3a832cc5..4f6e86b0e58b4e 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -536,7 +536,10 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
 // Considering comparisons from leaf and non-leaf nodes, we can estimate the
 // number of comparisons in a simple closed form :
 //   n + n / 2 - 1 = n * 3 / 2 - 1
-int64_t getExpectedNumberOfCompare(int NumCaseCluster) {
+int64_t getExpectedNumberOfCompare(int NumCaseCluster,
+                                   bool DefaultDestUndefined) {
+  if (DefaultDestUndefined)
+    return static_cast<int64_t>(NumCaseCluster) - 1;
   return 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
 }
 
@@ -724,9 +727,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       return;
     }
 
-    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
 
     addCost(SwitchCost);
@@ -1258,9 +1260,8 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
       return;
     }
 
-    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
 
     int64_t SwitchCost =
         ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost;



More information about the llvm-commits mailing list