[llvm] [WIP][Inline] Rewrite switch's handling of the default branch (PR #85160)

Quentin Dian via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 17 03:05:22 PDT 2024


https://github.com/DianQK updated https://github.com/llvm/llvm-project/pull/85160

>From 58622ef6755a02f97e5127bea29ed5b8812fe25e Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sun, 17 Mar 2024 16:17:24 +0800
Subject: [PATCH 1/6] [InlineCost] Ignore the switch default-destination cost

---
 llvm/lib/Analysis/InlineCost.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e947..8b495207fccc51 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -701,8 +701,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      addCost(2 * InstrCost);
+    // if (!DefaultDestUndefined)
+    //   addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
@@ -1235,9 +1235,9 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-                SwitchDefaultDestCostMultiplier * InstrCost);
+    // if (!DefaultDestUndefined)
+    //   increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+    //             SwitchDefaultDestCostMultiplier * InstrCost);
 
     if (JumpTableSize) {
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +

>From d69f4bc19bc04ed0b3c056938b711ea7d7712ef0 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Thu, 22 Feb 2024 20:31:44 +0800
Subject: [PATCH 2/6] [InlineCost] Account for the default branch when a
 switch is lowered to comparisons

---
 llvm/lib/Analysis/InlineCost.cpp | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 8b495207fccc51..ffffa15f3dd730 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -701,12 +701,14 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    // if (!DefaultDestUndefined)
-    //   addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
     if (JumpTableSize) {
+      // Suppose a default branch includes one compare and one conditional
+      // branch if it's reachable.
+      if (!DefaultDestUndefined)
+        addCost(2 * InstrCost);
       int64_t JTCost =
           static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
       addCost(JTCost);
@@ -715,10 +717,13 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
     if (NumCaseCluster <= 3) {
       // Suppose a comparison includes one compare and one conditional branch.
-      addCost(NumCaseCluster * 2 * InstrCost);
+      // We can reduce a set of instructions if the default branch is
+      // undefined.
+      addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
       return;
     }
 
+    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
         getExpectedNumberOfCompare(NumCaseCluster);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
@@ -1235,11 +1240,10 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    // if (!DefaultDestUndefined)
-    //   increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-    //             SwitchDefaultDestCostMultiplier * InstrCost);
-
     if (JumpTableSize) {
+      if (!DefaultDestUndefined)
+        increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+                  SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1248,10 +1252,12 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
     if (NumCaseCluster <= 3) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - DefaultDestUndefined) *
+                    CaseClusterCostMultiplier * InstrCost);
       return;
     }
 
+    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
         getExpectedNumberOfCompare(NumCaseCluster);
 

>From 9333f9171aedd73a15600a542b72c88fa8bd00d6 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Wed, 13 Mar 2024 22:04:22 +0800
Subject: [PATCH 3/6] [InlineCost] The jump table only requires a jump
 instruction

---
 llvm/lib/Analysis/InlineCost.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index ffffa15f3dd730..de5bad23bc728b 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -709,8 +709,9 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       // branch if it's reachable.
       if (!DefaultDestUndefined)
         addCost(2 * InstrCost);
+      // The jump table only requires a jump instruction.
       int64_t JTCost =
-          static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+          static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
       addCost(JTCost);
       return;
     }
@@ -1157,7 +1158,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   // FIXME: These constants are taken from the heuristic-based cost visitor.
   // These should be removed entirely in a later revision to avoid reliance on
   // heuristics in the ML inliner.
-  static constexpr int JTCostMultiplier = 4;
+  static constexpr int JTCostMultiplier = 1;
   static constexpr int CaseClusterCostMultiplier = 2;
   static constexpr int SwitchDefaultDestCostMultiplier = 2;
   static constexpr int SwitchCostMultiplier = 2;

>From 9f29ce41a6cfd512b910ced8ba603db1961e24ec Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sun, 17 Mar 2024 16:49:53 +0800
Subject: [PATCH 4/6] [InlineCost] Count one fewer comparison when the default branch is unreachable

---
 llvm/lib/Analysis/InlineCost.cpp | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index de5bad23bc728b..09611cd1c0ce73 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -536,8 +536,16 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
 // Considering comparisons from leaf and non-leaf nodes, we can estimate the
 // number of comparisons in a simple closed form :
 //   n + n / 2 - 1 = n * 3 / 2 - 1
-int64_t getExpectedNumberOfCompare(int NumCaseCluster) {
-  return 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
+int64_t getExpectedNumberOfCompare(int NumCaseCluster,
+                                   bool DefaultDestUndefined) {
+  int64_t ExpectedNumber = 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
+  // FIXME: The compare instruction count should be less than the branch count
+  // when default branch is undefined. But this will cause some performance
+  // regressions. At least, we can now try to remove a compare instruction.
+  if (DefaultDestUndefined) {
+    ExpectedNumber -= 1;
+  }
+  return ExpectedNumber;
 }
 
 /// FIXME: if it is necessary to derive from InlineCostCallAnalyzer, note
@@ -724,9 +732,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       return;
     }
 
-    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
 
     addCost(SwitchCost);
@@ -1258,9 +1265,8 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
       return;
     }
 
-    // FIXME: Consider the case when default branch is undefined.
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
 
     int64_t SwitchCost =
         ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost;

>From 9a7e12ca9ff9f666d3b96915a22123f9a2c623d2 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sun, 17 Mar 2024 16:51:38 +0800
Subject: [PATCH 5/6] [InlineCost] Update test cases

---
 llvm/test/Transforms/Inline/inline-switch-default-2.ll | 2 +-
 llvm/test/Transforms/Inline/inline-switch-default.ll   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/test/Transforms/Inline/inline-switch-default-2.ll b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
index 8d3e24c798df82..1a648300ae3c1e 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default-2.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=21 | FileCheck %s
+; RUN: opt %s -S -passes=inline -inline-threshold=11 | FileCheck %s
 
 ; Check for scenarios without TTI.
 
diff --git a/llvm/test/Transforms/Inline/inline-switch-default.ll b/llvm/test/Transforms/Inline/inline-switch-default.ll
index 44f1304e82dff0..6a50820aad3a7d 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=26 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
-; RUN: opt %s -S -passes=inline -inline-threshold=21 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
+; RUN: opt %s -S -passes=inline -inline-threshold=16 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
+; RUN: opt %s -S -passes=inline -inline-threshold=11 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"

>From dc2c2faa82d3d7b998680267a79895eb4969e6fd Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sun, 17 Mar 2024 18:05:01 +0800
Subject: [PATCH 6/6] [perf experiment] Estimate n - 1 comparisons when the
 default branch is unreachable

---
 llvm/lib/Analysis/InlineCost.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 09611cd1c0ce73..9d29d5765c1915 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -538,14 +538,12 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
 //   n + n / 2 - 1 = n * 3 / 2 - 1
 int64_t getExpectedNumberOfCompare(int NumCaseCluster,
                                    bool DefaultDestUndefined) {
-  int64_t ExpectedNumber = 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
-  // FIXME: The compare instruction count should be less than the branch count
-  // when default branch is undefined. But this will cause some performance
-  // regressions. At least, we can now try to remove a compare instruction.
+  // The compare instruction count should be less than the branch count
+  // when default branch is undefined.
   if (DefaultDestUndefined) {
-    ExpectedNumber -= 1;
+    return static_cast<int64_t>(NumCaseCluster) - 1;
   }
-  return ExpectedNumber;
+  return 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
 }
 
 /// FIXME: if it is necessary to derive from InlineCostCallAnalyzer, note



More information about the llvm-commits mailing list