[llvm] [InlineCost] Correct the default branch cost for the switch statement (PR #85160)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 17 05:27:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Quentin Dian (DianQK)

<details>
<summary>Changes</summary>

I used the following patch to find functions that are no longer inlined after #<!-- -->77856.

<details><summary>patch.diff</summary>
<p>

```diff
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e9..e325d18ab0a8 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
+#include "llvm/Demangle/Demangle.h"
 #include "llvm/IR/AssemblyAnnotationWriter.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/DataLayout.h"
@@ -575,6 +576,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   // True if the cost-benefit-analysis-based inliner is enabled.
   const bool CostBenefitAnalysisEnabled;
 
+  int DefaultBranchCost = 0;
+
   /// Inlining cost measured in abstract units, accounts for all the
   /// instructions expected to be executed for a given function invocation.
   /// Instructions that are statically proven to be dead based on call-site
@@ -701,8 +704,11 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
+    if (!DefaultDestUndefined) {
+      DefaultBranchCost = std::clamp<int64_t>(DefaultBranchCost + 2 * InstrCost,
+                                              INT_MIN, INT_MAX);
       addCost(2 * InstrCost);
+    }
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
@@ -1132,6 +1138,7 @@ public:
   virtual ~InlineCostCallAnalyzer() = default;
   int getThreshold() const { return Threshold; }
   int getCost() const { return Cost; }
+  int getDefaultBranchCost() const { return DefaultBranchCost; }
   int getStaticBonusApplied() const { return StaticBonusApplied; }
   std::optional<CostBenefitPair> getCostBenefitPair() { return CostBenefit; }
   bool wasDecidedByCostBenefit() const { return DecidedByCostBenefit; }
@@ -3072,6 +3079,23 @@ InlineCost llvm::getInlineCost(
                             GetAssumptionCache, GetBFI, PSI, ORE);
   InlineResult ShouldInline = CA.analyze();
 
+  if (CA.getCost() > CA.getThreshold() &&
+      (CA.getCost() - CA.getDefaultBranchCost() <= CA.getThreshold())) {
+    auto ModuleName = Callee->getParent()->getName();
+    auto *CallerName = llvm::itaniumDemangle(Call.getCaller()->getName());
+    auto *CalleeName = llvm::itaniumDemangle(Callee->getName());
+    errs() << "NOT Inlining ModuleName: " << ModuleName << " Caller: " << CallerName
+           << ", Callee: " << CalleeName << ", Cost: " << CA.getCost()
+           << ", Threshold: " << CA.getThreshold()
+           << ", DefaultBranchCost: " << CA.getDefaultBranchCost();
+    if (auto *SP = Callee->getSubprogram()) {
+       auto FileName = SP->getFilename();
+       unsigned Line = SP->getLine();
+       errs() << ", FileName: " << FileName << "#L" << Line;
+    }
+    errs() << "\n";
+  }
+
   LLVM_DEBUG(CA.dump());
 
   // Always make cost benefit based decision explicit.
```

</p>
</details> 

There are over 20,000 call sites that don't satisfy the inline condition. I tried to select 10 of them:

<details><summary>Details</summary>
<p>

- Callee: `DoLowering(llvm::Function&, llvm::GCStrategy&)` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/llvm/lib/CodeGen/GCRootLowering.cpp#L201
- Callee: `getFromRangeMetadata(llvm::Instruction*)` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/llvm/lib/Analysis/LazyValueInfo.cpp#L589
- Callee: `clang::comments::DeclInfo::fill()` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/AST/Comment.cpp#L203
- Callee: `clang::APValue::DestroyDataAndMakeUninit()` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/AST/APValue.cpp#L403
- Callee: `clang::targets::MipsTargetInfo::getISARev() const` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/Basic/Targets/Mips.cpp#L61
- Callee: `llvm::isLegalUTF8(unsigned char const*, int)` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/llvm/lib/Support/ConvertUTF.cpp#L397
- Callee: `llvm::yaml::Input::createHNodes(llvm::yaml::Node*)` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/llvm/lib/Support/YAMLTraits.cpp#L401
- Callee: `clang::Parser::ParseOpenACCDirective()` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/Parse/ParseOpenACC.cpp#L1119
- Callee: `clang::OMPClauseReader::readClause()` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/Serialization/ASTReader.cpp#L10263
- Callee: `clang::CodeGen::CodeGenFunction::EmitLandingPad()` FileName: https://github.com/llvm/llvm-project/blob/5aec9392674572fa5a06283173a6a739742d261d/clang/lib/CodeGen/CGException.cpp#L825

</p>
</details> 

These functions contain complex switch statements that cannot be transformed into simpler structures.

The earliest commit of the related code is: https://github.com/llvm/llvm-project/commit/919f9e8d65ada6552b8b8a5ec12ea49db91c922a. I tried to understand the following code with https://github.com/llvm/llvm-project/pull/77856#issuecomment-1993499085.

https://github.com/llvm/llvm-project/blob/5932fcc47855fdd209784f38820422d2369b84b2/llvm/lib/Analysis/InlineCost.cpp#L709-L720

I think only scenarios where there is a default branch were considered.

Taking https://llvm.godbolt.org/z/5cno1TnGx as an example, we need an additional compare instruction and an additional jump instruction when there is a default branch; otherwise, we just need a single jump instruction.

```asm
foo: # @<!-- -->foo
  cmp rdi, 6
  ja .LBB0_6
  jmp qword ptr [8*rdi + .LJTI0_0]
...
bar: # @<!-- -->bar
  jmp qword ptr [8*rdi + .LJTI1_0]
...
```

But I don't know why it's `4 * InstrCost` and not `3 * InstrCost`.

Taking https://llvm.godbolt.org/z/MEsf9sno7 as an example, when the number of branches is small and the default branch is undefined, we can eliminate one pair of compare and jump instructions.

```asm
foo: # @<!-- -->foo
  cmp rdi, 4
  je .LBB0_5
  cmp rdi, 2
  je .LBB0_4
  test rdi, rdi
  jne .LBB0_6
...
bar: # @<!-- -->bar
  cmp rdi, 4
  je .LBB1_4
  cmp rdi, 2
  jne .LBB1_2
...
```

Further, I found that in scenarios with more branches, the number of generated compare instructions should be less than the number of branches if the default branch is undefined behavior. There will be even fewer compare instructions if some cases share a common branch.

Result of reverting #<!-- -->77856: https://llvm-compile-time-tracker.com/compare.php?from=f3c5278efa3b783ada9e7a34b751cf4c5b864535&to=58622ef6755a02f97e5127bea29ed5b8812fe25e&stat=instructions:u.
New change: https://llvm-compile-time-tracker.com/compare.php?from=58622ef6755a02f97e5127bea29ed5b8812fe25e&to=dc2c2faa82d3d7b998680267a79895eb4969e6fd&stat=instructions%3Au.

---
Full diff: https://github.com/llvm/llvm-project/pull/85160.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/InlineCost.cpp (+24-13) 
- (modified) llvm/test/Transforms/Inline/inline-switch-default-2.ll (+1-1) 
- (modified) llvm/test/Transforms/Inline/inline-switch-default.ll (+2-2) 


``````````diff
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index e55eaa55f8e947..9d29d5765c1915 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -536,7 +536,13 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
 // Considering comparisons from leaf and non-leaf nodes, we can estimate the
 // number of comparisons in a simple closed form :
 //   n + n / 2 - 1 = n * 3 / 2 - 1
-int64_t getExpectedNumberOfCompare(int NumCaseCluster) {
+int64_t getExpectedNumberOfCompare(int NumCaseCluster,
+                                   bool DefaultDestUndefined) {
+  // The compare instruction count should be less than the branch count
+  // when default branch is undefined.
+  if (DefaultDestUndefined) {
+    return static_cast<int64_t>(NumCaseCluster) - 1;
+  }
   return 3 * static_cast<int64_t>(NumCaseCluster) / 2 - 1;
 }
 
@@ -701,26 +707,31 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      addCost(2 * InstrCost);
     // If suitable for a jump table, consider the cost for the table size and
     // branch to destination.
     // Maximum valid cost increased in this function.
     if (JumpTableSize) {
+      // Suppose a default branch includes one compare and one conditional
+      // branch if it's reachable.
+      if (!DefaultDestUndefined)
+        addCost(2 * InstrCost);
+      // The jump table only requires a jump instruction.
       int64_t JTCost =
-          static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
+          static_cast<int64_t>(JumpTableSize) * InstrCost + InstrCost;
       addCost(JTCost);
       return;
     }
 
     if (NumCaseCluster <= 3) {
       // Suppose a comparison includes one compare and one conditional branch.
-      addCost(NumCaseCluster * 2 * InstrCost);
+      // We can reduce a set of instructions if the default branch is
+      // undefined.
+      addCost((NumCaseCluster - DefaultDestUndefined) * 2 * InstrCost);
       return;
     }
 
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
     int64_t SwitchCost = ExpectedNumberOfCompare * 2 * InstrCost;
 
     addCost(SwitchCost);
@@ -1152,7 +1163,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   // FIXME: These constants are taken from the heuristic-based cost visitor.
   // These should be removed entirely in a later revision to avoid reliance on
   // heuristics in the ML inliner.
-  static constexpr int JTCostMultiplier = 4;
+  static constexpr int JTCostMultiplier = 1;
   static constexpr int CaseClusterCostMultiplier = 2;
   static constexpr int SwitchDefaultDestCostMultiplier = 2;
   static constexpr int SwitchCostMultiplier = 2;
@@ -1235,11 +1246,10 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onFinalizeSwitch(unsigned JumpTableSize, unsigned NumCaseCluster,
                         bool DefaultDestUndefined) override {
-    if (!DefaultDestUndefined)
-      increment(InlineCostFeatureIndex::switch_default_dest_penalty,
-                SwitchDefaultDestCostMultiplier * InstrCost);
-
     if (JumpTableSize) {
+      if (!DefaultDestUndefined)
+        increment(InlineCostFeatureIndex::switch_default_dest_penalty,
+                  SwitchDefaultDestCostMultiplier * InstrCost);
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
       increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
@@ -1248,12 +1258,13 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
     if (NumCaseCluster <= 3) {
       increment(InlineCostFeatureIndex::case_cluster_penalty,
-                NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
+                (NumCaseCluster - DefaultDestUndefined) *
+                    CaseClusterCostMultiplier * InstrCost);
       return;
     }
 
     int64_t ExpectedNumberOfCompare =
-        getExpectedNumberOfCompare(NumCaseCluster);
+        getExpectedNumberOfCompare(NumCaseCluster, DefaultDestUndefined);
 
     int64_t SwitchCost =
         ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost;
diff --git a/llvm/test/Transforms/Inline/inline-switch-default-2.ll b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
index 8d3e24c798df82..1a648300ae3c1e 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default-2.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default-2.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=21 | FileCheck %s
+; RUN: opt %s -S -passes=inline -inline-threshold=11 | FileCheck %s
 
 ; Check for scenarios without TTI.
 
diff --git a/llvm/test/Transforms/Inline/inline-switch-default.ll b/llvm/test/Transforms/Inline/inline-switch-default.ll
index 44f1304e82dff0..6a50820aad3a7d 100644
--- a/llvm/test/Transforms/Inline/inline-switch-default.ll
+++ b/llvm/test/Transforms/Inline/inline-switch-default.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt %s -S -passes=inline -inline-threshold=26 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
-; RUN: opt %s -S -passes=inline -inline-threshold=21 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
+; RUN: opt %s -S -passes=inline -inline-threshold=16 -min-jump-table-entries=4 | FileCheck %s -check-prefix=LOOKUPTABLE
+; RUN: opt %s -S -passes=inline -inline-threshold=11 -min-jump-table-entries=5 | FileCheck %s -check-prefix=SWITCH
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"

``````````

</details>


https://github.com/llvm/llvm-project/pull/85160


More information about the llvm-commits mailing list