[llvm] [InlineCost] Correct the default branch cost for the switch statement (PR #85160)

Quentin Dian via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 18:58:54 PDT 2024


https://github.com/DianQK updated https://github.com/llvm/llvm-project/pull/85160

>From 98b32349eb3fdda673a415681524ae3e288ecd26 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Sun, 17 Mar 2024 16:17:24 +0800
Subject: [PATCH] [InlineCost] Correct the default branch cost for the switch
 statement

---
 llvm/lib/Analysis/InlineCost.cpp              |  25 ++--
 .../Inline/inline-cost-switch-default.ll      | 130 ++++++++++++++++++
 .../Inline/inline-switch-default-2.ll         |  21 +--
 .../Inline/inline-switch-default.ll           |  25 +---
 4 files changed, 152 insertions(+), 49 deletions(-)
 create mode 100644 llvm/test/Transforms/Inline/inline-cost-switch-default.ll

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index c75460f44c1d9f..e4989db816d009 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -701,21 +701,26 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
     if (JumpTableSize) {
+      // If the default destination is reachable, assume it costs one compare
+      // and one conditional branch.
+      if (!DefaultDestUndefined)
+        addCost(2 * InstrCost);
+      // Suppose a jump table requires one load and one jump instruction.
       int64_t JTCost =
-          static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+          static_cast<int64_t>(JumpTableSize) * InstrCost + 2 * InstrCost;
       addCost(JTCost);
       return;
     }
 
     if (NumCaseCluster <= 3) {
       // Suppose a comparison includes one compare and one conditional branch.
-      addCost(NumCaseCluster * 2 * InstrCost);
+      // One compare-and-branch pair can be dropped if the default branch is
+      // undefined.
+      addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
       return;
     }
 
@@ -1152,7 +1157,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   // FIXME: These constants are taken from the heuristic-based cost visitor.
   // These should be removed entirely in a later revision to avoid reliance on
   // heuristics in the ML inliner.
-  static constexpr int JTCostMultiplier = 4;
+  static constexpr int JTCostMultiplier = 2;
   static constexpr int CaseClusterCostMultiplier = 2;
   static constexpr int SwitchDefaultDestCostMultiplier = 2;
   static constexpr int SwitchCostMultiplier = 2;
@@ -1235,11 +1240,10 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-                SwitchDefaultDestCostMultiplier * InstrCost);
-
     if (JumpTableSize) {
+      if (!DefaultDestUndefined)
+        increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+                  SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1248,7 +1252,8 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
     if (NumCaseCluster <= 3) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - DefaultDestUndefined) *
+                    CaseClusterCostMultiplier * InstrCost);
       return;
     }
 
diff --git a/llvm/test/Transforms/Inline/inline-cost-switch-default.ll b/llvm/test/Transforms/Inline/inline-cost-switch-default.ll
new file mode 100644
index 00000000000000..3710d560521fb3
--- /dev/null
+++ b/llvm/test/Transforms/Inline/inline-cost-switch-default.ll
@@ -0,0 +1,130 @@
+; RUN: opt -S -passes=inline %s -debug-only=inline-cost -min-jump-table-entries=4 --disable-output 2>&1 | FileCheck %s -check-prefix=LOOKUPTABLE -match-full-lines
+; RUN: opt -S -passes=inline %s -debug-only=inline-cost -min-jump-table-entries=5 --disable-output 2>&1 | FileCheck %s -check-prefix=SWITCH -match-full-lines
+; REQUIRES: asserts
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define i64 @main(i64 %a) {
+  %b = call i64 @small_switch_default(i64 %a)
+  %c = call i64 @small_switch_no_default(i64 %a)
+  %d = call i64 @lookup_table_default(i64 %a)
+  %e = call i64 @lookup_table_no_default(i64 %a)
+  ret i64 %b
+}
+
+; SWITCH-LABEL: Analyzing call of small_switch_default{{.*}}
+; SWITCH: Cost: 0
+define i64 @small_switch_default(i64 %a) {
+  switch i64 %a, label %default_branch [
+  i64 -1, label %branch_0
+  i64 8, label %branch_1
+  i64 52, label %branch_2
+  ]
+
+branch_0:
+  br label %exit
+
+branch_1:
+  br label %exit
+
+branch_2:
+  br label %exit
+
+default_branch:
+  br label %exit
+
+exit:
+  %b = phi i64 [ 5, %branch_0 ], [ 9, %branch_1 ], [ 2, %branch_2 ], [ 3, %default_branch ]
+  ret i64 %b
+}
+
+; SWITCH-LABEL: Analyzing call of small_switch_no_default{{.*}}
+; SWITCH: Cost: -10
+define i64 @small_switch_no_default(i64 %a) {
+  switch i64 %a, label %unreachabledefault [
+  i64 -1, label %branch_0
+  i64 8, label %branch_1
+  i64 52, label %branch_2
+  ]
+
+branch_0:
+  br label %exit
+
+branch_1:
+  br label %exit
+
+branch_2:
+  br label %exit
+
+unreachabledefault:
+  unreachable
+
+exit:
+  %b = phi i64 [ 5, %branch_0 ], [ 9, %branch_1 ], [ 2, %branch_2 ]
+  ret i64 %b
+}
+
+; LOOKUPTABLE-LABEL: Analyzing call of lookup_table_default{{.*}}
+; LOOKUPTABLE: Cost: 10
+; SWITCH-LABEL: Analyzing call of lookup_table_default{{.*}}
+; SWITCH: Cost: 20
+define i64 @lookup_table_default(i64 %a) {
+  switch i64 %a, label %default_branch [
+  i64 0, label %branch_0
+  i64 1, label %branch_1
+  i64 2, label %branch_2
+  i64 3, label %branch_3
+  ]
+
+branch_0:
+  br label %exit
+
+branch_1:
+  br label %exit
+
+branch_2:
+  br label %exit
+
+branch_3:
+  br label %exit
+
+default_branch:
+  br label %exit
+
+exit:
+  %b = phi i64 [ 5, %branch_0 ], [ 9, %branch_1 ], [ 2, %branch_2 ], [ 7, %branch_3 ], [ 3, %default_branch ]
+  ret i64 %b
+}
+
+; LOOKUPTABLE-LABEL: Analyzing call of lookup_table_no_default{{.*}}
+; LOOKUPTABLE: Cost: 0
+; SWITCH-LABEL: Analyzing call of lookup_table_no_default{{.*}}
+; SWITCH: Cost: 20
+define i64 @lookup_table_no_default(i64 %a) {
+  switch i64 %a, label %unreachabledefault [
+  i64 0, label %branch_0
+  i64 1, label %branch_1
+  i64 2, label %branch_2
+  i64 3, label %branch_3
+  ]
+
+branch_0:
+  br label %exit
+
+branch_1:
+  br label %exit
+
+branch_2:
+  br label %exit
+
+branch_3:
+  br label %exit
+
+unreachabledefault:
+  unreachable
+
+exit:
+  %b = phi i64 [ 5, %branch_0 ], [ 9, %branch_1 ], [ 2, %branch_2 ], [ 7, %branch_3 ]
+  ret i64 %b
+}
diff --git a/llvm/test/Transforms/Inline/inline-switch-default-2.ll b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
index 82dae1c27648fc..169cb2cff9b82c 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default-2.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=21 | FileCheck %s
+; RUN: opt %s -S -passes=inline -inline-threshold=11 | FileCheck %s
 
 ; Check for scenarios without TTI.
 
@@ -16,24 +16,7 @@ define i64 @foo1(i64 %a) {
 define i64 @foo2(i64 %a) {
 ; CHECK-LABEL: define i64 @foo2(
 ; CHECK-SAME: i64 [[A:%.*]]) {
-; CHECK-NEXT:    switch i64 [[A]], label [[UNREACHABLEDEFAULT_I:%.*]] [
-; CHECK-NEXT:      i64 0, label [[BRANCH_0_I:%.*]]
-; CHECK-NEXT:      i64 2, label [[BRANCH_2_I:%.*]]
-; CHECK-NEXT:      i64 4, label [[BRANCH_4_I:%.*]]
-; CHECK-NEXT:      i64 6, label [[BRANCH_6_I:%.*]]
-; CHECK-NEXT:    ]
-; CHECK:       branch_0.i:
-; CHECK-NEXT:    br label [[BAR2_EXIT:%.*]]
-; CHECK:       branch_2.i:
-; CHECK-NEXT:    br label [[BAR2_EXIT]]
-; CHECK:       branch_4.i:
-; CHECK-NEXT:    br label [[BAR2_EXIT]]
-; CHECK:       branch_6.i:
-; CHECK-NEXT:    br label [[BAR2_EXIT]]
-; CHECK:       unreachabledefault.i:
-; CHECK-NEXT:    unreachable
-; CHECK:       bar2.exit:
-; CHECK-NEXT:    [[B_I:%.*]] = phi i64 [ 5, [[BRANCH_0_I]] ], [ 9, [[BRANCH_2_I]] ], [ 2, [[BRANCH_4_I]] ], [ 7, [[BRANCH_6_I]] ]
+; CHECK-NEXT:    [[B_I:%.*]] = call i64 @bar2(i64 [[A]])
 ; CHECK-NEXT:    ret i64 [[B_I]]
 ;
   %b = call i64 @bar2(i64 %a)
diff --git a/llvm/test/Transforms/Inline/inline-switch-default.ll b/llvm/test/Transforms/Inline/inline-switch-default.ll
index 44f1304e82dff0..48789667865b44 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=26 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
-; RUN: opt %s -S -passes=inline -inline-threshold=21 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
+; RUN: opt %s -S -passes=inline -inline-threshold=16 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
+; RUN: opt %s -S -passes=inline -inline-threshold=11 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -22,6 +22,8 @@ define i64 @foo1(i64 %a) {
   ret i64 %b
 }
 
+; Since the default branch is undefined behavior,
+; we can inline `bar2`: https://github.com/llvm/llvm-project/issues/90929
 define i64 @foo2(i64 %a) {
 ; LOOKUPTABLE-LABEL: define i64 @foo2(
 ; LOOKUPTABLE-SAME: i64 [[A:%.*]]) {
@@ -47,24 +49,7 @@ define i64 @foo2(i64 %a) {
 ;
 ; SWITCH-LABEL: define i64 @foo2(
 ; SWITCH-SAME: i64 [[A:%.*]]) {
-; SWITCH-NEXT:    switch i64 [[A]], label [[UNREACHABLEDEFAULT_I:%.*]] [
-; SWITCH-NEXT:      i64 0, label [[BRANCH_0_I:%.*]]
-; SWITCH-NEXT:      i64 2, label [[BRANCH_2_I:%.*]]
-; SWITCH-NEXT:      i64 4, label [[BRANCH_4_I:%.*]]
-; SWITCH-NEXT:      i64 6, label [[BRANCH_6_I:%.*]]
-; SWITCH-NEXT:    ]
-; SWITCH:       branch_0.i:
-; SWITCH-NEXT:    br label [[BAR2_EXIT:%.*]]
-; SWITCH:       branch_2.i:
-; SWITCH-NEXT:    br label [[BAR2_EXIT]]
-; SWITCH:       branch_4.i:
-; SWITCH-NEXT:    br label [[BAR2_EXIT]]
-; SWITCH:       branch_6.i:
-; SWITCH-NEXT:    br label [[BAR2_EXIT]]
-; SWITCH:       unreachabledefault.i:
-; SWITCH-NEXT:    unreachable
-; SWITCH:       bar2.exit:
-; SWITCH-NEXT:    [[B_I:%.*]] = phi i64 [ 5, [[BRANCH_0_I]] ], [ 9, [[BRANCH_2_I]] ], [ 2, [[BRANCH_4_I]] ], [ 7, [[BRANCH_6_I]] ]
+; SWITCH-NEXT:    [[B_I:%.*]] = call i64 @bar2(i64 [[A]])
 ; SWITCH-NEXT:    ret i64 [[B_I]]
 ;
   %b = call i64 @bar2(i64 %a)



More information about the llvm-commits mailing list