[llvm] [WIP][Inline] Rewrite switch's handling of the default branch (PR #85160)
Quentin Dian via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 13 17:25:43 PDT 2024
https://github.com/DianQK created https://github.com/llvm/llvm-project/pull/85160
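This PR fixes the inline-cost calculation for a `switch` whose default destination is undefined behavior (i.e. unreachable), so that no compare-and-branch is charged for reaching such a default. It also revises how the jump-table dispatch cost is computed.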
>From b74b7b125c3243a2113b57fa1b39017918b804c8 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Thu, 22 Feb 2024 20:31:44 +0800
Subject: [PATCH 1/3] [Inline] Consider the default branch when transforming to
comparison

Only charge for the switch's default destination when it is reachable
(DefaultDestUndefined == false).

- The up-front "2 * InstrCost if the default is reachable" charge (and
  the matching switch_default_dest_penalty feature) now lives inside the
  jump-table branch instead of being added for every lowering strategy.
- Small-switch (compare chain) path: the threshold becomes
  NumCaseCluster + !DefaultDestUndefined <= 4, i.e. still <= 3 clusters
  when the default is reachable and <= 4 when it is undefined. One fewer
  compare+branch pair is charged when the default is undefined.
- Binary-search path (getExpectedNumberOfCompare): not adjusted yet; a
  FIXME marks that it does not account for an undefined default.

NOTE(review): the two analyzers disagree on the small-switch path.
InlineCostCallAnalyzer subtracts `DefaultDestUndefined`, while
InlineCostFeaturesAnalyzer subtracts `!DefaultDestUndefined`, so the
features analyzer drops a pair when the default is *reachable*. These
should presumably agree -- confirm.

NOTE(review): NumCaseCluster is unsigned. If it can be 0 when a pair is
subtracted, the subtraction wraps around to a huge value instead of
going negative -- confirm whether such a switch can reach this code.
---
llvm/lib/Analysis/InlineCost.cpp | 25 +++++++++++++++----------
1 file changed, 15 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e947..e26fcfe8876f22 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -701,24 +701,29 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
bool DefaultDestUndefined) override {
- if (!DefaultDestUndefined)
- addCost(2 * InstrCost);
// If suitable for a jump table, consider the cost for the table size and
// branch to destination.
// Maximum valid cost increased in this function.
if (JumpTableSize) {
+ // Suppose a default branch includes one compare and one conditional
+ // branch if it's reachable.
+ if (!DefaultDestUndefined)
+ addCost(2 * InstrCost);
int64_t JTCost =
static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
addCost(JTCost);
return;
}
- if (NumCaseCluster <= 3) {
+ if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
// Suppose a comparison includes one compare and one conditional branch.
- addCost(NumCaseCluster * 2 * InstrCost);
+ // We can create one less set of instructions if the default branch is
+ // undefined.
+ addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
return;
}
+ // FIXME: Consider the case when default branch is undefined.
int64_t ExpectedNumberOfCompare =
getExpectedNumberOfCompare(NumCaseCluster);
int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
@@ -1235,23 +1240,23 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
bool DefaultDestUndefined) override {
- if (!DefaultDestUndefined)
- increment(InlineCostFeatureIndex::switch_default_dest_penalty,
- SwitchDefaultDestCostMultiplier * InstrCost);
-
if (JumpTableSize) {
+ if (!DefaultDestUndefined)
+ increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+ SwitchDefaultDestCostMultiplier * InstrCost);
int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
JTCostMultiplier * InstrCost;
increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
return;
}
- if (NumCaseCluster <= 3) {
+ if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
increment(InlineCostFeatureIndex::case_cluster_penalty,
- NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+ (NumCaseCluster - !DefaultDestUndefined) * CaseClusterCostMultiplier * InstrCost);
return;
}
+ // FIXME: Consider the case when default branch is undefined.
int64_t ExpectedNumberOfCompare =
getExpectedNumberOfCompare(NumCaseCluster);
>From 0a525581096c889f5122f4008d2fe8eef9af071e Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Wed, 13 Mar 2024 22:04:22 +0800
Subject: [PATCH 2/3] [Inline] The jump table only requires a jump instruction

Lower the fixed part of the jump-table cost from 4 * InstrCost to
1 * InstrCost. This applies to InlineCostCallAnalyzer and to
InlineCostFeaturesAnalyzer (JTCostMultiplier 4 -> 1), keeping the
features analyzer's constant aligned with the heuristic cost visitor it
was copied from.

Also fix the small-switch case-cluster penalty in
InlineCostFeaturesAnalyzer. It subtracted `!DefaultDestUndefined`, so it
dropped a compare+branch pair when the default destination was
*reachable*. It now subtracts `DefaultDestUndefined`, matching
InlineCostCallAnalyzer, which charges one fewer pair only when the
default destination is undefined.
---
llvm/lib/Analysis/InlineCost.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e26fcfe8876f22..4b505d3a832cc5 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -709,8 +709,9 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// branch if it's reachable.
if (!DefaultDestUndefined)
addCost(2 * InstrCost);
+ // The jump table only requires a jump instruction.
int64_t JTCost =
- static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+ static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
addCost(JTCost);
return;
}
@@ -1157,7 +1158,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
// FIXME: These constants are taken from the heuristic-based cost visitor.
// These should be removed entirely in a later revision to avoid reliance on
// heuristics in the ML inliner.
- static constexpr int JTCostMultiplier = 4;
+ static constexpr int JTCostMultiplier = 1;
static constexpr int CaseClusterCostMultiplier = 2;
static constexpr int SwitchDefaultDestCostMultiplier = 2;
static constexpr int SwitchCostMultiplier = 2;
@@ -1243,7 +1244,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
if (JumpTableSize) {
if (!DefaultDestUndefined)
increment(InlineCostFeatureIndex::switch_default_dest_penalty,
- SwitchDefaultDestCostMultiplier * InstrCost);
+ SwitchDefaultDestCostMultiplier * InstrCost);
int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
JTCostMultiplier * InstrCost;
increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1252,7 +1253,8 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
if ((NumCaseCluster + !DefaultDestUndefined) <= 4) {
increment(InlineCostFeatureIndex::case_cluster_penalty,
- (NumCaseCluster - !DefaultDestUndefined) * CaseClusterCostMultiplier * InstrCost);
+ (NumCaseCluster - DefaultDestUndefined) *
+ CaseClusterCostMultiplier * InstrCost);
return;
}
>From 813c9b04d4c88936eae1650cdc7494fc057cd117 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Wed, 13 Mar 2024 22:48:53 +0800
Subject: [PATCH 3/3] [Inline] Ignore jump table size

Stop scaling the jump-table cost with JumpTableSize.

- InlineCostCallAnalyzer now charges a flat InstrCost for the jump-table
  dispatch, plus the existing 2 * InstrCost when the default destination
  is reachable.
- InlineCostFeaturesAnalyzer records JTCostMultiplier * InstrCost as
  jump_table_penalty, plus the existing switch_default_dest_penalty when
  the default is reachable.

In the code shown, JumpTableSize is now only tested for non-zero, to
select the jump-table path.

NOTE(review): the comment at the top of
InlineCostCallAnalyzer::onFinalizeSwitch ("If suitable for a jump table,
consider the cost for the table size and branch to destination.") no
longer matches the code after this change and should be updated.
---
llvm/lib/Analysis/InlineCost.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 4b505d3a832cc5..3ed7b85afb9059 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -710,8 +710,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
if (!DefaultDestUndefined)
addCost(2 * InstrCost);
// The jump table only requires a jump instruction.
- int64_t JTCost =
- static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
+ int64_t JTCost = InstrCost;
addCost(JTCost);
return;
}
@@ -1245,8 +1244,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
if (!DefaultDestUndefined)
increment(InlineCostFeatureIndex::switch_default_dest_penalty,
SwitchDefaultDestCostMultiplier * InstrCost);
- int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
- JTCostMultiplier * InstrCost;
+ int64_t JTCost = JTCostMultiplier * InstrCost;
increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
return;
}
More information about the llvm-commits
mailing list