[llvm] [AsmPrinter][ELF] Support profile-guided section prefix for jump tables' (read-only) data sections (PR #122215)

Mingming Liu via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 28 13:04:08 PST 2025


https://github.com/mingmingl-llvm updated https://github.com/llvm/llvm-project/pull/122215

>From 5d207e9f4738b2c49e171a4d9270a294de44ecb4 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Wed, 8 Jan 2025 14:28:43 -0800
Subject: [PATCH 01/17] [SDP]Introduce StaticDataSplitter pass and implement
 jump table splitting

---
 llvm/include/llvm/CodeGen/MachineBasicBlock.h |   6 +
 llvm/include/llvm/CodeGen/MachineFunction.h   |   9 +
 .../llvm/CodeGen/MachineJumpTableInfo.h       |   8 +-
 llvm/include/llvm/CodeGen/Passes.h            |   4 +
 llvm/include/llvm/InitializePasses.h          |   1 +
 .../llvm/Passes/MachinePassRegistry.def       |   1 +
 llvm/lib/CodeGen/CMakeLists.txt               |   1 +
 llvm/lib/CodeGen/CodeGen.cpp                  |   1 +
 llvm/lib/CodeGen/MachineBasicBlock.cpp        |   4 +
 llvm/lib/CodeGen/MachineFunction.cpp          |  13 ++
 llvm/lib/CodeGen/StaticDataSplitter.cpp       | 153 ++++++++++++++++
 llvm/lib/CodeGen/TargetPassConfig.cpp         |   1 +
 llvm/test/CodeGen/X86/jump-table-partition.ll | 163 ++++++++++++++++++
 13 files changed, 363 insertions(+), 2 deletions(-)
 create mode 100644 llvm/lib/CodeGen/StaticDataSplitter.cpp
 create mode 100644 llvm/test/CodeGen/X86/jump-table-partition.ll

diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
index 7fe33c3913f2dd..9ed356c2ab2331 100644
--- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h
+++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
@@ -997,6 +997,12 @@ class MachineBasicBlock
   /// no changes occurred in the meantime.
   bool canSplitCriticalEdge(const MachineBasicBlock *Succ) const;
 
+  /// Return an index into MachineJumpTableInfo if this basic block ends with
+  /// an indirect jump using a jump table, otherwise -1.
+  /// This function is a thin wrapper that forwards to the per-target method
+  /// `TargetInstrInfo::getJumpTableIndex`.
+  int getJumpTableIndex() const;
+
   void pop_front() { Insts.pop_front(); }
   void pop_back() { Insts.pop_back(); }
   void push_back(MachineInstr *MI) { Insts.push_back(MI); }
diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h
index d696add8a1af53..c0f983d1c6787b 100644
--- a/llvm/include/llvm/CodeGen/MachineFunction.h
+++ b/llvm/include/llvm/CodeGen/MachineFunction.h
@@ -88,6 +88,15 @@ template <> struct ilist_callback_traits<MachineBasicBlock> {
   }
 };
 
+// The hotness of static data tracked by a MachineFunction and not represented
+// as a global object in the module IR / MIR. Typical examples are
+// MachineJumpTableInfo and MachineConstantPool.
+enum class DataHotness {
+  Unknown, // Enumerators are ordered by increasing hotness; updaters keep the
+  Cold,    // maximum value seen (see MachineJumpTableInfo's hotness update),
+  Hot,     // so do not reorder them.
+};
+
 /// MachineFunctionInfo - This class can be derived from and used by targets to
 /// hold private target-specific information for each MachineFunction.  Objects
 /// of type are accessed/created with MF::getInfo and destroyed when the
diff --git a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
index e8e9c2f6338e06..cc1f54a81b9bb4 100644
--- a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
@@ -28,6 +28,7 @@ namespace llvm {
 class MachineBasicBlock;
 class DataLayout;
 class raw_ostream;
+enum class DataHotness;
 
 /// MachineJumpTableEntry - One jump table in the jump table info.
 ///
@@ -35,8 +36,9 @@ struct MachineJumpTableEntry {
   /// MBBs - The vector of basic blocks from which to create the jump table.
   std::vector<MachineBasicBlock*> MBBs;
 
-  explicit MachineJumpTableEntry(const std::vector<MachineBasicBlock*> &M)
-  : MBBs(M) {}
+  DataHotness Hotness;
+
+  explicit MachineJumpTableEntry(const std::vector<MachineBasicBlock *> &M);
 };
 
 class MachineJumpTableInfo {
@@ -107,6 +109,8 @@ class MachineJumpTableInfo {
     return JumpTables;
   }
 
+  // Raise jump table JTI's recorded hotness to Hotness; it is never lowered.
+  void updateJumpTableHotness(size_t JTI, DataHotness Hotness);
   /// RemoveJumpTable - Mark the specific index as being dead.  This will
   /// prevent it from being emitted.
   void RemoveJumpTable(unsigned Idx) {
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index d1fac4a304cffe..16423d03ff7018 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -71,6 +71,10 @@ namespace llvm {
   /// using profile information.
   MachineFunctionPass *createMachineFunctionSplitterPass();
 
+  /// createStaticDataSplitterPass - This pass partitions static data sections
+  /// into hot and cold sections using profile information.
+  MachineFunctionPass *createStaticDataSplitterPass();
+
   /// MachineFunctionPrinter pass - This pass prints out the machine function to
   /// the given stream as a debugging tool.
   MachineFunctionPass *
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index 1cb9013bc48cc5..8111afcc1fb20f 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -293,6 +293,7 @@ void initializeSpeculativeExecutionLegacyPassPass(PassRegistry &);
 void initializeSpillPlacementWrapperLegacyPass(PassRegistry &);
 void initializeStackColoringLegacyPass(PassRegistry &);
 void initializeStackFrameLayoutAnalysisPassPass(PassRegistry &);
+void initializeStaticDataSplitterPass(PassRegistry &);
 void initializeStackMapLivenessPass(PassRegistry &);
 void initializeStackProtectorPass(PassRegistry &);
 void initializeStackSafetyGlobalInfoWrapperPassPass(PassRegistry &);
diff --git a/llvm/include/llvm/Passes/MachinePassRegistry.def b/llvm/include/llvm/Passes/MachinePassRegistry.def
index 5a4e79d7225db1..f4cddafa009711 100644
--- a/llvm/include/llvm/Passes/MachinePassRegistry.def
+++ b/llvm/include/llvm/Passes/MachinePassRegistry.def
@@ -236,6 +236,7 @@ DUMMY_MACHINE_FUNCTION_PASS("livedebugvalues", LiveDebugValuesPass)
 DUMMY_MACHINE_FUNCTION_PASS("lrshrink", LiveRangeShrinkPass)
 DUMMY_MACHINE_FUNCTION_PASS("machine-combiner", MachineCombinerPass)
 DUMMY_MACHINE_FUNCTION_PASS("machine-cp", MachineCopyPropagationPass)
+DUMMY_MACHINE_FUNCTION_PASS("static-data-splitter", StaticDataSplitter)
 DUMMY_MACHINE_FUNCTION_PASS("machine-function-splitter", MachineFunctionSplitterPass)
 DUMMY_MACHINE_FUNCTION_PASS("machine-latecleanup", MachineLateInstrsCleanupPass)
 DUMMY_MACHINE_FUNCTION_PASS("machine-sanmd", MachineSanitizerBinaryMetadata)
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index 145fd2fac8b564..88f863d8204d09 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -226,6 +226,7 @@ add_llvm_component_library(LLVMCodeGen
   StackMaps.cpp
   StackProtector.cpp
   StackSlotColoring.cpp
+  StaticDataSplitter.cpp
   SwiftErrorValueTracking.cpp
   SwitchLoweringUtils.cpp
   TailDuplication.cpp
diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp
index 8efe540770913a..84d92705de0223 100644
--- a/llvm/lib/CodeGen/CodeGen.cpp
+++ b/llvm/lib/CodeGen/CodeGen.cpp
@@ -130,6 +130,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
   initializeStackMapLivenessPass(Registry);
   initializeStackProtectorPass(Registry);
   initializeStackSlotColoringPass(Registry);
+  initializeStaticDataSplitterPass(Registry);
   initializeStripDebugMachineModulePass(Registry);
   initializeTailDuplicateLegacyPass(Registry);
   initializeTargetPassConfigPass(Registry);
diff --git a/llvm/lib/CodeGen/MachineBasicBlock.cpp b/llvm/lib/CodeGen/MachineBasicBlock.cpp
index 5ac6472a01e9fc..af298c5994fbb6 100644
--- a/llvm/lib/CodeGen/MachineBasicBlock.cpp
+++ b/llvm/lib/CodeGen/MachineBasicBlock.cpp
@@ -1426,6 +1426,10 @@ bool MachineBasicBlock::canSplitCriticalEdge(
   return true;
 }
 
+int MachineBasicBlock::getJumpTableIndex() const {
+  return findJumpTableIndex(*this); // Jump table index, or -1 if none.
+}
+
 /// Prepare MI to be removed from its bundle. This fixes bundle flags on MI's
 /// neighboring instructions so the bundle won't be broken by removing MI.
 static void unbundleSingleMI(MachineInstr *MI) {
diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp
index e6b9538fe9a02c..b5a89f3bcf42f1 100644
--- a/llvm/lib/CodeGen/MachineFunction.cpp
+++ b/llvm/lib/CodeGen/MachineFunction.cpp
@@ -1291,6 +1291,10 @@ const unsigned MachineFunction::DebugOperandMemNumber = 1000000;
 //  MachineJumpTableInfo implementation
 //===----------------------------------------------------------------------===//
 
+MachineJumpTableEntry::MachineJumpTableEntry(
+    const std::vector<MachineBasicBlock *> &M)
+    : MBBs(M), Hotness(DataHotness::Unknown) {}
+
 /// Return the size of each entry in the jump table.
 unsigned MachineJumpTableInfo::getEntrySize(const DataLayout &TD) const {
   // The size of a jump table entry is 4 bytes unless the entry is just the
@@ -1340,6 +1344,15 @@ unsigned MachineJumpTableInfo::createJumpTableIndex(
   return JumpTables.size()-1;
 }
 
+void MachineJumpTableInfo::updateJumpTableHotness(size_t JTI,
+                                                  DataHotness Hotness) {
+  assert(JTI < JumpTables.size() && "Invalid JTI!");
+  // Only ever raise the recorded hotness. Keeping the hottest value seen is
+  // what mergeable data (constant pools) needs; jump tables follow suit.
+  if (DataHotness &Current = JumpTables[JTI].Hotness; Current < Hotness)
+    Current = Hotness;
+}
+
 /// If Old is the target of any jump tables, update the jump tables to branch
 /// to New instead.
 bool MachineJumpTableInfo::ReplaceMBBInJumpTables(MachineBasicBlock *Old,
diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
new file mode 100644
index 00000000000000..14b9c1b3394d2e
--- /dev/null
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -0,0 +1,153 @@
+//===- StaticDataSplitter.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass uses profile information to partition static data sections into
+// hot and cold ones. It begins to split jump tables based on profile, and
+// subsequent patches will handle constant pools and other module internal data.
+//
+// For the original RFC of this pass please see
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744.
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/CodeGen/MBFIWrapper.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
+#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
+#include "llvm/CodeGen/MachineConstantPool.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "static-data-splitter"
+
+STATISTIC(NumHotJumpTables, "Number of hot jump tables seen");
+STATISTIC(NumColdJumpTables, "Number of cold jump tables seen");
+STATISTIC(NumUnknownJumpTables,
+          "Number of jump tables with unknown hotness. Such jump tables will "
+          "be placed in the hot-suffixed section by default.");
+
+class StaticDataSplitter : public MachineFunctionPass {
+  const MachineBranchProbabilityInfo *MBPI = nullptr; // NOTE(review): unread.
+  const MachineBlockFrequencyInfo *MBFI = nullptr;
+  const ProfileSummaryInfo *PSI = nullptr;
+
+  // Returns true iff hotness was assigned to at least one jump table.
+  bool splitJumpTables(MachineFunction &MF);
+
+  // Profile-guided variant of the above, for functions with profile data.
+  bool splitJumpTablesWithProfiles(MachineFunction &MF,
+                                   MachineJumpTableInfo &MJTI);
+
+public:
+  static char ID;
+
+  StaticDataSplitter() : MachineFunctionPass(ID) {
+    initializeStaticDataSplitterPass(*PassRegistry::getPassRegistry());
+  }
+
+  StringRef getPassName() const override { return "Static Data Splitter"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    MachineFunctionPass::getAnalysisUsage(AU);
+    AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
+    AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
+    AU.addRequired<ProfileSummaryInfoWrapperPass>();
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+};
+
+bool StaticDataSplitter::runOnMachineFunction(MachineFunction &MF) {
+  PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+  MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
+  MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
+
+  // Only jump tables are partitioned for now; constant pools and other
+  // module-internal data are left to follow-up patches.
+  return splitJumpTables(MF);
+}
+
+bool StaticDataSplitter::splitJumpTablesWithProfiles(
+    MachineFunction &MF, MachineJumpTableInfo &MJTI) {
+  int NumChangedJumpTables = 0;
+  for (const auto &MBB : MF) {
+    // IMPORTANT, `getJumpTableIndex` is a thin wrapper around per-target
+    // interface `TargetInstrInfo::getJumpTableIndex`, and only X86 implements
+    // it so far.
+    const int JTI = MBB.getJumpTableIndex();
+    // This is not a source block of a jump table.
+    if (JTI == -1)
+      continue;
+
+    // Regard a jump table as hot by default. If the source and all of the
+    // destination blocks are cold, regard the jump table as cold. `Hotness`
+    // is computed per jump table; hoisting it out of the loop would let one
+    // cold table leak its hotness into every later table in the function.
+    bool AllBlocksCold = PSI->isColdBlock(&MBB, MBFI);
+    for (const MachineBasicBlock *DestMBB : MJTI.getJumpTables()[JTI].MBBs)
+      if (!PSI->isColdBlock(DestMBB, MBFI))
+        AllBlocksCold = false;
+
+    DataHotness Hotness = DataHotness::Hot;
+    if (AllBlocksCold) {
+      Hotness = DataHotness::Cold;
+      ++NumColdJumpTables;
+    } else {
+      ++NumHotJumpTables;
+    }
+
+    // `MJTI` is the function's jump table info (the caller passes
+    // `*MF.getJumpTableInfo()`), so update it directly.
+    MJTI.updateJumpTableHotness(JTI, Hotness);
+    ++NumChangedJumpTables;
+  }
+  return NumChangedJumpTables > 0;
+}
+
+bool StaticDataSplitter::splitJumpTables(MachineFunction &MF) {
+  MachineJumpTableInfo *MJTI = MF.getJumpTableInfo();
+  if (!MJTI || MJTI->getJumpTables().empty())
+    return false;
+
+  // Place jump tables according to block hotness if block counters are
+  // available. Check function entry count because BFI depends on it to derive
+  // block counters.
+  if (PSI && PSI->hasProfileSummary() && MBFI &&
+      MF.getFunction().getEntryCount())
+    return splitJumpTablesWithProfiles(MF, *MJTI);
+
+  // Conservatively place all jump tables in the hot-suffixed section if profile
+  // information for the function is not available. Note that with profiles,
+  // tables the target cannot map via `getJumpTableIndex` stay Unknown.
+  const size_t NumJumpTables = MJTI->getJumpTables().size();
+  for (size_t JTI = 0; JTI < NumJumpTables; ++JTI)
+    MJTI->updateJumpTableHotness(JTI, DataHotness::Hot);
+  NumUnknownJumpTables += NumJumpTables;
+  return true;
+}
+
+char StaticDataSplitter::ID = 0; // Identity token for MachineFunctionPass.
+
+INITIALIZE_PASS_BEGIN(StaticDataSplitter, DEBUG_TYPE, "Split static data",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
+INITIALIZE_PASS_END(StaticDataSplitter, DEBUG_TYPE, "Split static data", false,
+                    false)
+// Factory declared in llvm/CodeGen/Passes.h.
+MachineFunctionPass *llvm::createStaticDataSplitterPass() {
+  return new StaticDataSplitter();
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index d407e9f0871d4c..23929672f11d68 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -1256,6 +1256,7 @@ void TargetPassConfig::addMachinePasses() {
                "performance.\n";
       }
     }
+    addPass(createStaticDataSplitterPass());
     addPass(createMachineFunctionSplitterPass());
   }
   // We run the BasicBlockSections pass if either we need BB sections or BB
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
new file mode 100644
index 00000000000000..3a8f1395f6b283
--- /dev/null
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -0,0 +1,163 @@
+;; -stats requires asserts
+; REQUIRES: asserts
+
+; RUN: llc -stop-after=block-placement %s -o - | llc --run-pass=static-data-splitter -stats -x mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
+
+; `func_with_hot_jumptable` contains a hot jump table and `func_with_cold_jumptable` contains a cold one.
+; `func_without_entry_count` simulates the functions without profile information (e.g., not instrumented or not profiled),
+; its jump table hotness is unknown and regarded as hot conservatively.
+;
+; Test that the expected stat messages are emitted.
+; TODO: Update test to verify section suffixes when target-lowering and assembler changes are implemented.
+;
+; STAT-DAG: 1 static-data-splitter - Number of cold jump tables seen
+; STAT-DAG: 1 static-data-splitter - Number of hot jump tables seen
+; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@.str.2 = private constant [7 x i8] c"case 3\00"
+@.str.3 = private constant [7 x i8] c"case 4\00"
+@.str.4 = private constant [7 x i8] c"case 5\00"
+@str.9 = private constant [7 x i8] c"case 2\00"
+@str.10 = private constant [7 x i8] c"case 1\00"
+@str.11 = private constant [8 x i8] c"default\00"
+
+define i32 @func_with_hot_jumptable(i32 %num) !prof !13 {
+entry:
+  switch i32 %num, label %sw.default [
+    i32 1, label %sw.bb
+    i32 2, label %sw.bb1
+    i32 3, label %sw.bb3
+    i32 4, label %sw.bb5
+    i32 5, label %sw.bb7
+  ], !prof !14
+
+sw.bb:                                            ; preds = %entry
+  %puts11 = tail call i32 @puts(ptr @str.10)
+  br label %sw.epilog
+
+sw.bb1:                                           ; preds = %entry
+  %puts = tail call i32 @puts(ptr @str.9)
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %entry
+  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
+  br label %sw.bb5
+
+sw.bb5:                                           ; preds = %entry, %sw.bb3
+  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
+  br label %sw.bb7
+
+sw.bb7:                                           ; preds = %entry, %sw.bb5
+  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %entry
+  %puts12 = tail call i32 @puts(ptr @str.11)
+  br label %sw.epilog
+
+sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
+  %div = sdiv i32 %num, 3
+  ret i32 %div
+}
+
+define void @func_with_cold_jumptable(i32 %num) !prof !15 {
+entry:
+  switch i32 %num, label %sw.default [
+    i32 1, label %sw.bb
+    i32 2, label %sw.bb1
+    i32 3, label %sw.bb3
+    i32 4, label %sw.bb5
+    i32 5, label %sw.bb7
+  ], !prof !16
+
+sw.bb:                                            ; preds = %entry
+  %puts10 = tail call i32 @puts(ptr @str.10)
+  br label %sw.epilog
+
+sw.bb1:                                           ; preds = %entry
+  %puts = tail call i32 @puts(ptr @str.9)
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %entry
+  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
+  br label %sw.bb5
+
+sw.bb5:                                           ; preds = %entry, %sw.bb3
+  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
+  br label %sw.bb7
+
+sw.bb7:                                           ; preds = %entry, %sw.bb5
+  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %entry
+  %puts11 = tail call i32 @puts(ptr @str.11)
+  br label %sw.epilog
+
+sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
+  ret void
+}
+
+define void @func_without_entry_count(i32 %num) {
+entry:
+  switch i32 %num, label %sw.default [
+    i32 1, label %sw.bb
+    i32 2, label %sw.bb1
+    i32 3, label %sw.bb3
+    i32 4, label %sw.bb5
+    i32 5, label %sw.bb7
+  ]
+
+sw.bb:                                            ; preds = %entry
+  %puts10 = tail call i32 @puts(ptr @str.10)
+  br label %sw.epilog
+
+sw.bb1:                                           ; preds = %entry
+  %puts = tail call i32 @puts(ptr @str.9)
+  br label %sw.epilog
+
+sw.bb3:                                           ; preds = %entry
+  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
+  br label %sw.bb5
+
+sw.bb5:                                           ; preds = %entry, %sw.bb3
+  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
+  br label %sw.bb7
+
+sw.bb7:                                           ; preds = %entry, %sw.bb5
+  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
+  br label %sw.epilog
+
+sw.default:                                       ; preds = %entry
+  %puts11 = tail call i32 @puts(ptr @str.11)
+  br label %sw.epilog
+
+sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
+  ret void
+}
+
+declare i32 @puts(ptr)
+declare i32 @printf(ptr, ...)
+
+!llvm.module.flags = !{!0}
+
+!0 = !{i32 1, !"ProfileSummary", !1}
+!1 = !{!2, !3, !4, !5, !6, !7, !8, !9}
+!2 = !{!"ProfileFormat", !"InstrProf"}
+!3 = !{!"TotalCount", i64 230002}
+!4 = !{!"MaxCount", i64 100000}
+!5 = !{!"MaxInternalCount", i64 50000}
+!6 = !{!"MaxFunctionCount", i64 100000}
+!7 = !{!"NumCounts", i64 14}
+!8 = !{!"NumFunctions", i64 3}
+!9 = !{!"DetailedSummary", !10}
+!10 = !{!11, !12}
+!11 = !{i32 990000, i64 10000, i32 7}
+!12 = !{i32 999999, i64 1, i32 9}
+!13 = !{!"function_entry_count", i64 100000}
+!14 = !{!"branch_weights", i32 50000, i32 10000, i32 10000, i32 10000, i32 10000, i32 10000}
+!15 = !{!"function_entry_count", i64 1}
+!16 = !{!"branch_weights", i32 1, i32 0, i32 0, i32 0, i32 0, i32 0}

>From 34b6b9b45564d994844cc9610edddf026a0e49cc Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 9 Jan 2025 11:09:45 -0800
Subject: [PATCH 02/17] Flag-gate the new pass and resolve review feedback

---
 llvm/lib/CodeGen/StaticDataSplitter.cpp | 11 ++++++-----
 llvm/lib/CodeGen/TargetPassConfig.cpp   |  8 +++++++-
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
index 14b9c1b3394d2e..482a61027cf985 100644
--- a/llvm/lib/CodeGen/StaticDataSplitter.cpp
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -83,7 +83,10 @@ bool StaticDataSplitter::splitJumpTablesWithProfiles(
     MachineFunction &MF, MachineJumpTableInfo &MJTI) {
   int NumChangedJumpTables = 0;
   // Regard a jump table as hot by default. If the source and all of destination
-  // blocks are cold, regard the jump table as cold.
+  // blocks are cold, regard the jump table as cold. While a destination block
+  // does not read a jump table (unless it's also a source block), a hot
+  // destination heuristically makes its jump table hot to accommodate for
+  // potential profile data skews (from sampled profiles, for example).
   DataHotness Hotness = DataHotness::Hot;
   for (const auto &MBB : MF) {
     // IMPORTANT, `getJumpTableIndex` is a thin wrapper around per-target
@@ -121,11 +124,9 @@ bool StaticDataSplitter::splitJumpTables(MachineFunction &MF) {
   if (!MJTI || MJTI->getJumpTables().empty())
     return false;
 
-  // Place jump tables according to block hotness if block counters are
-  // available. Check function entry count because BFI depends on it to derive
-  // block counters.
+  // Place jump tables according to block hotness if function has profile data.
   if (PSI && PSI->hasProfileSummary() && MBFI &&
-      MF.getFunction().getEntryCount())
+      MF.getFunction().hasProfileData())
     return splitJumpTablesWithProfiles(MF, *MJTI);
 
   // Conservatively place all jump tables in the hot-suffixed section if profile
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 23929672f11d68..6a964c0910fc61 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -263,6 +263,11 @@ static cl::opt<bool>
     GCEmptyBlocks("gc-empty-basic-blocks", cl::init(false), cl::Hidden,
                   cl::desc("Enable garbage-collecting empty basic blocks"));
 
+static cl::opt<bool>
+    SplitStaticData("split-static-data", cl::Hidden, cl::init(false),
+                    cl::desc("Split static data sections into hot and cold "
+                             "ones using profile information"));
+
 /// Allow standard passes to be disabled by command line options. This supports
 /// simple binary flags that either suppress the pass or do nothing.
 /// i.e. -disable-mypass=false has no effect.
@@ -1256,8 +1261,9 @@ void TargetPassConfig::addMachinePasses() {
                "performance.\n";
       }
     }
-    addPass(createStaticDataSplitterPass());
     addPass(createMachineFunctionSplitterPass());
+    if (SplitStaticData)
+      addPass(createStaticDataSplitterPass());
   }
   // We run the BasicBlockSections pass if either we need BB sections or BB
   // address map (or both).

>From dd748277dff2b30ed02bfa466eeca7102aa93eb4 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 10 Jan 2025 13:53:08 -0800
Subject: [PATCH 03/17] reply to upstream

---
 llvm/include/llvm/CodeGen/MachineFunction.h   |   2 +-
 .../llvm/CodeGen/MachineJumpTableInfo.h       |   9 +-
 llvm/include/llvm/CodeGen/Passes.h            |   2 +-
 llvm/lib/CodeGen/MachineFunction.cpp          |  12 +-
 llvm/lib/CodeGen/StaticDataSplitter.cpp       |  87 +++---
 llvm/test/CodeGen/X86/jump-table-partition.ll | 251 +++++++++++-------
 6 files changed, 223 insertions(+), 140 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h
index c0f983d1c6787b..dcdbcaec168d22 100644
--- a/llvm/include/llvm/CodeGen/MachineFunction.h
+++ b/llvm/include/llvm/CodeGen/MachineFunction.h
@@ -91,7 +91,7 @@ template <> struct ilist_callback_traits<MachineBasicBlock> {
 // The hotness of static data tracked by a MachineFunction and not represented
 // as a global object in the module IR / MIR. Typical examples are
 // MachineJumpTableInfo and MachineConstantPool.
-enum class DataHotness {
+enum class MachineFunctionDataHotness {
   Unknown,
   Cold,
   Hot,
diff --git a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
index cc1f54a81b9bb4..e3675d6489b350 100644
--- a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
@@ -28,7 +28,7 @@ namespace llvm {
 class MachineBasicBlock;
 class DataLayout;
 class raw_ostream;
-enum class DataHotness;
+enum class MachineFunctionDataHotness;
 
 /// MachineJumpTableEntry - One jump table in the jump table info.
 ///
@@ -36,7 +36,7 @@ struct MachineJumpTableEntry {
   /// MBBs - The vector of basic blocks from which to create the jump table.
   std::vector<MachineBasicBlock*> MBBs;
 
-  DataHotness Hotness;
+  MachineFunctionDataHotness Hotness;
 
   explicit MachineJumpTableEntry(const std::vector<MachineBasicBlock *> &M);
 };
@@ -109,7 +109,10 @@ class MachineJumpTableInfo {
     return JumpTables;
   }
 
-  void updateJumpTableHotness(size_t JTI, DataHotness Hotness);
+  // Update machine jump table entry's hotness. Return true if the hotness is
+  // updated.
+  bool updateJumpTableEntryHotness(size_t JTI,
+                                   MachineFunctionDataHotness Hotness);
 
   /// RemoveJumpTable - Mark the specific index as being dead.  This will
   /// prevent it from being emitted.
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 16423d03ff7018..b5d2a7e6bf035b 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -71,7 +71,7 @@ namespace llvm {
   /// using profile information.
   MachineFunctionPass *createMachineFunctionSplitterPass();
 
-  /// createStaticDataSplitterPass - This pass partions static data sections
+  /// createStaticDataSplitterPass - This pass partitions a static data section
   /// into a hot and cold section using profile information.
   MachineFunctionPass *createStaticDataSplitterPass();
 
diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp
index b5a89f3bcf42f1..d09e93d79aae6c 100644
--- a/llvm/lib/CodeGen/MachineFunction.cpp
+++ b/llvm/lib/CodeGen/MachineFunction.cpp
@@ -1293,7 +1293,7 @@ const unsigned MachineFunction::DebugOperandMemNumber = 1000000;
 
 MachineJumpTableEntry::MachineJumpTableEntry(
     const std::vector<MachineBasicBlock *> &MBBs)
-    : MBBs(MBBs), Hotness(DataHotness::Unknown) {}
+    : MBBs(MBBs), Hotness(MachineFunctionDataHotness::Unknown) {}
 
 /// Return the size of each entry in the jump table.
 unsigned MachineJumpTableInfo::getEntrySize(const DataLayout &TD) const {
@@ -1344,13 +1344,17 @@ unsigned MachineJumpTableInfo::createJumpTableIndex(
   return JumpTables.size()-1;
 }
 
-void MachineJumpTableInfo::updateJumpTableHotness(size_t JTI,
-                                                  DataHotness Hotness) {
+bool MachineJumpTableInfo::updateJumpTableEntryHotness(
+    size_t JTI, MachineFunctionDataHotness Hotness) {
   assert(JTI < JumpTables.size() && "Invalid JTI!");
   // Note record the largest hotness is important for mergable data (constant
   // pools). Even if jump table instances are not merged, record the largest
   // value seen fwiw.
-  JumpTables[JTI].Hotness = std::max(JumpTables[JTI].Hotness, Hotness);
+  if (Hotness <= JumpTables[JTI].Hotness)
+    return false;
+
+  JumpTables[JTI].Hotness = Hotness;
+  return true;
 }
 
 /// If Old is the target of any jump tables, update the jump tables to branch
diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
index 482a61027cf985..9e2cfe18256e35 100644
--- a/llvm/lib/CodeGen/StaticDataSplitter.cpp
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -6,13 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass uses profile information to partition static data sections into
-// hot and cold ones. It begins to split jump tables based on profile, and
-// subsequent patches will handle constant pools and other module internal data.
+// The pass uses branch profile data to assign hotness-based section qualifiers
+// for the following types of static data:
+// - Jump tables
+// - Constant pools (TODO)
+// - Other module-internal data (TODO)
 //
 // For the original RFC of this pass please see
-// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744.
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/CodeGen/MBFIWrapper.h"
@@ -35,8 +38,15 @@ using namespace llvm;
 STATISTIC(NumHotJumpTables, "Number of hot jump tables seen");
 STATISTIC(NumColdJumpTables, "Number of cold jump tables seen");
 STATISTIC(NumUnknownJumpTables,
-          "Number of jump tables with unknown hotness. Such jump tables will "
-          "be placed in the hot-suffixed section by default.");
+          "Number of jump tables with unknown hotness. Option "
+          "-static-data-default-hotness specifies the hotness.");
+
+static cl::opt<MachineFunctionDataHotness> StaticDataDefaultHotness(
+    "static-data-default-hotness", cl::Hidden,
+    cl::desc("The hotness for static data with unknown hotness"),
+    cl::init(MachineFunctionDataHotness::Hot),
+    cl::values(clEnumValN(MachineFunctionDataHotness::Hot, "hot", "Hot"),
+               clEnumValN(MachineFunctionDataHotness::Cold, "cold", "Cold")));
 
 class StaticDataSplitter : public MachineFunctionPass {
   const MachineBranchProbabilityInfo *MBPI = nullptr;
@@ -82,12 +92,7 @@ bool StaticDataSplitter::runOnMachineFunction(MachineFunction &MF) {
 bool StaticDataSplitter::splitJumpTablesWithProfiles(
     MachineFunction &MF, MachineJumpTableInfo &MJTI) {
   int NumChangedJumpTables = 0;
-  // Regard a jump table as hot by default. If the source and all of destination
-  // blocks are cold, regard the jump table as cold. While a destination block
-  // does not read a jump table (unless it's also a source block), a hot
-  // destination heuristically makes its jump table hot to accommodate for
-  // potential profile data skews (from sampled profiles, for example).
-  DataHotness Hotness = DataHotness::Hot;
+
   for (const auto &MBB : MF) {
     // IMPORTANT, `getJumpTableIndex` is a thin wrapper around per-target
     // interface `TargetInstrInfo::getjumpTableIndex`, and only X86 implements
@@ -97,24 +102,14 @@ bool StaticDataSplitter::splitJumpTablesWithProfiles(
     if (JTI == -1)
       continue;
 
-    bool AllBlocksCold = true;
-
-    if (!PSI->isColdBlock(&MBB, MBFI))
-      AllBlocksCold = false;
+    auto Hotness = MachineFunctionDataHotness::Hot;
 
-    for (const MachineBasicBlock *MBB : MJTI.getJumpTables()[JTI].MBBs)
-      if (!PSI->isColdBlock(MBB, MBFI))
-        AllBlocksCold = false;
-
-    if (AllBlocksCold) {
-      Hotness = DataHotness::Cold;
-      ++NumColdJumpTables;
-    } else {
-      ++NumHotJumpTables;
-    }
+    // A jump table is cold iff the block that reads it is cold.
+    if (PSI->isColdBlock(&MBB, MBFI))
+      Hotness = MachineFunctionDataHotness::Cold;
 
-    MF.getJumpTableInfo()->updateJumpTableHotness(JTI, Hotness);
-    ++NumChangedJumpTables;
+    if (MJTI.updateJumpTableEntryHotness(JTI, Hotness))
+      ++NumChangedJumpTables;
   }
   return NumChangedJumpTables > 0;
 }
@@ -124,18 +119,40 @@ bool StaticDataSplitter::splitJumpTables(MachineFunction &MF) {
   if (!MJTI || MJTI->getJumpTables().empty())
     return false;
 
+  const bool ProfileAvailable = PSI && PSI->hasProfileSummary() && MBFI &&
+                                MF.getFunction().hasProfileData();
+  auto StatOnExit = llvm::make_scope_exit([&] {
+    if (!AreStatisticsEnabled())
+      return;
+
+    if (!ProfileAvailable) {
+      NumUnknownJumpTables += MJTI->getJumpTables().size();
+      return;
+    }
+
+    // Even with a profile, a jump table whose reading block was not identified
+    // (e.g. the target does not implement `getJumpTableIndex`) stays unknown.
+    for (const auto &JTE : MJTI->getJumpTables()) {
+      const auto Hotness = JTE.Hotness;
+      if (Hotness == MachineFunctionDataHotness::Hot)
+        ++NumHotJumpTables;
+      else if (Hotness == MachineFunctionDataHotness::Cold)
+        ++NumColdJumpTables;
+      else
+        ++NumUnknownJumpTables;
+    }
+  });
+
   // Place jump tables according to block hotness if function has profile data.
-  if (PSI && PSI->hasProfileSummary() && MBFI &&
-      MF.getFunction().hasProfileData())
+  if (ProfileAvailable)
     return splitJumpTablesWithProfiles(MF, *MJTI);
 
-  // Conservatively place all jump tables in the hot-suffixed section if profile
-  // information for the function is not available, or the target doesn't
-  // implement `TargetInstrInfo::getJumpTableIndex` yet.
+  // If function profile is unavailable, -static-data-default-hotness specifies
+  // the hotness.
   for (size_t JTI = 0; JTI < MJTI->getJumpTables().size(); JTI++)
-    MF.getJumpTableInfo()->updateJumpTableHotness(JTI, DataHotness::Hot);
+    MF.getJumpTableInfo()->updateJumpTableEntryHotness(
+        JTI, StaticDataDefaultHotness);
 
-  NumUnknownJumpTables += MJTI->getJumpTables().size();
   return true;
 }
 
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 3a8f1395f6b283..3b5af2c098553f 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -1,104 +1,172 @@
-;; -stats requires asserts
-; requires: asserts
+; -stats requires asserts
+; REQUIRES: asserts
 
-; RUN: llc -stop-after=block-placement %s -o - | llc --run-pass=static-data-splitter -stats -x mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
+; Stop after 'finalize-isel' for simpler MIR, and lower the minimum number of
+; jump table entries so 'switch' needs fewer cases to generate a jump table.
+; RUN: llc -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
+; RUN: llc --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
 
-; `func_with_hot_jumptable` contains a hot jump table and `func_with_cold_jumptable` contains a cold one. 
-; `func_without_entry_count` simulates the functions without profile information (e.g., not instrumented or not profiled),
-; it's jump table hotness is unknown and regarded as hot conservatively.
-;
 ; Tests stat messages are expected.
 ; TODO: Update test to verify section suffixes when target-lowering and assembler changes are implemented.
-;
-; STAT-DAG: 1 static-data-splitter - Number of cold jump tables seen
-; STAT-DAG: 1 static-data-splitter - Number of hot jump tables seen
+; TODO: Also run static-data-splitter pass with -static-data-default-hotness=cold and check data section suffix.
+
+; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
+; STAT-DAG: 3 static-data-splitter - Number of hot jump tables seen
 ; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
 
+; @foo has four jump tables, jt0, jt1, jt2 and jt3 in the input basic block
+; order; jt0 and jt2 are hot, and jt1 and jt3 are cold.
+;
+; @func_with_hot_jt is a function with one entry count, and a hot loop using a
+; jump table.
+
+; @func_without_entry_count simulates functions without profile information
+; (e.g., not instrumented or not profiled); its jump table hotness is unknown
+; and regarded as hot conservatively.
+
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
 
-@.str.2 = private constant [7 x i8] c"case 3\00"
-@.str.3 = private constant [7 x i8] c"case 4\00"
-@.str.4 = private constant [7 x i8] c"case 5\00"
-@str.9 = private constant [7 x i8] c"case 2\00"
-@str.10 = private constant [7 x i8] c"case 1\00"
-@str.11 = private constant [8 x i8] c"default\00"
+@str.9 = private constant [7 x i8] c".str.9\00"
+@str.10 = private constant [8 x i8] c".str.10\00"
+@str.11 = private constant [8 x i8] c".str.11\00"
 
-define i32 @func_with_hot_jumptable(i32 %num) !prof !13 {
+@case2 = private constant [7 x i8] c"case 2\00"
+@case1 = private constant [7 x i8] c"case 1\00"
+@default = private constant [8 x i8] c"default\00"
+@jt3 = private constant [3 x i8] c"jt\00"
+
+; jt0 and jt2 are hot. jt1 and jt3 are cold.
+define i32 @foo(i32 %num) !prof !13 {
 entry:
-  switch i32 %num, label %sw.default [
-    i32 1, label %sw.bb
-    i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
+  %mod3 = sdiv i32 %num, 3
+  switch i32 %mod3, label %jt0.default [
+    i32 1, label %jt0.bb1
+    i32 2, label %jt0.bb2
   ], !prof !14
 
-sw.bb:                                            ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.10)
-  br label %sw.epilog
+jt0.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt0.epilog
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
+jt0.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt0.epilog
 
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
+jt0.default:
+  call i32 @puts(ptr @default)
+  br label %jt0.epilog
 
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
+jt0.epilog:
+  %zero = icmp eq i32 %num, 0
+  br i1 %zero, label %cold, label %hot, !prof !17
 
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
-  br label %sw.epilog
+cold:
+  %c1 = call i32 @compute(i32 %num)
+  switch i32 %c1, label %jt1.default [
+    i32 1, label %jt1.bb1
+    i32 2, label %jt1.bb2
+  ], !prof !14
 
-sw.default:                                       ; preds = %entry
-  %puts12 = tail call i32 @puts(ptr @str.11)
-  br label %sw.epilog
+jt1.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt1.epilog
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
-  %div = sdiv i32 %num, 3
-  ret i32 %div
-}
+jt1.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt1.epilog
 
-define void @func_with_cold_jumptable(i32 %num) !prof !15 {
-entry:
-  switch i32 %num, label %sw.default [
-    i32 1, label %sw.bb
-    i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
-  ], !prof !16
+jt1.default:
+  call i32 @puts(ptr @default)
+  br label %jt1.epilog
 
-sw.bb:                                            ; preds = %entry
-  %puts10 = tail call i32 @puts(ptr @str.10)
-  br label %sw.epilog
+jt1.epilog:
+  br label %return
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
+hot:
+  %c2 = call i32 @transform(i32 %num)
+  switch i32 %c2, label %jt2.default [
+    i32 1, label %jt2.bb1
+    i32 2, label %jt2.bb2
+  ], !prof !14
 
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
+jt2.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt1.epilog
 
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
+jt2.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt1.epilog
 
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
-  br label %sw.epilog
+jt2.default:
+  call i32 @puts(ptr @default)
+  br label %jt2.epilog
 
-sw.default:                                       ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.11)
-  br label %sw.epilog
+jt2.epilog:
+  %c2cmp = icmp ne i32 %c2, 0
+  br i1 %c2cmp, label %return, label %jt3.prologue, !prof !18
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
-  ret void
+jt3.prologue:
+  %c3 = call i32 @cleanup(i32 %num)
+  switch i32 %c3, label %jt3.default [
+    i32 1, label %jt3.bb1
+    i32 2, label %jt3.bb2
+  ], !prof !14
+
+jt3.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt3.epilog
+
+jt3.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt3.epilog
+
+jt3.default:
+  call i32 @puts(ptr @default)
+  br label %jt3.epilog
+
+jt3.epilog:
+  call i32 @puts(ptr @jt3)
+  br label %return
+
+return:
+  ret i32 %mod3
+}
+
+define i32 @func_with_hot_jt() !prof !15 {
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret i32 0
+
+for.body:
+  %lsr.iv = phi i32 [ 100000, %entry ], [ %lsr.iv.next, %loop.exit ]
+  %i.04 = phi i32 [ 0, %entry ], [ %inc, %loop.exit ]
+  %0 = urem i32 %i.04, 100
+  switch i32 %0, label %sw.default [
+    i32 1, label %loop.exit
+    i32 2, label %sw.bb1.i
+    i32 3, label %sw.bb3.i
+  ], !prof !19
+
+sw.bb1.i:
+  br label %loop.exit
+
+sw.bb3.i:
+  call i32 (ptr, ...) @printf(ptr @case1)
+  br label %sw.default
+
+sw.default:
+  br label %loop.exit
+
+loop.exit:
+  %str.5.sink.i = phi ptr [ @str.10, %sw.default ], [ @str.9, %sw.bb1.i ], [ @case2, %for.body ]
+  call i32 @puts(ptr %str.5.sink.i)
+  %inc = add i32 %i.04, 1
+  %lsr.iv.next = add i32 %lsr.iv, -1
+  %exitcond.not = icmp eq i32 %lsr.iv.next, 0
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !prof !17
 }
 
 define void @func_without_entry_count(i32 %num) {
@@ -106,41 +174,29 @@ entry:
   switch i32 %num, label %sw.default [
     i32 1, label %sw.bb
     i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
   ]
 
-sw.bb:                                            ; preds = %entry
-  %puts10 = tail call i32 @puts(ptr @str.10)
+sw.bb:
+  call i32 @puts(ptr @str.10)
   br label %sw.epilog
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
-
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
-
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
-
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
+sw.bb1:
+  call i32 @puts(ptr @str.9)
   br label %sw.epilog
 
-sw.default:                                       ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.11)
+sw.default:
+  call i32 @puts(ptr @str.11)
   br label %sw.epilog
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
+sw.epilog:
   ret void
 }
 
 declare i32 @puts(ptr)
 declare i32 @printf(ptr, ...)
+declare i32 @compute(i32)
+declare i32 @transform(i32)
+declare i32 @cleanup(i32)
 
 !llvm.module.flags = !{!0}
 
@@ -158,6 +214,9 @@ declare i32 @printf(ptr, ...)
 !11 = !{i32 990000, i64 10000, i32 7}
 !12 = !{i32 999999, i64 1, i32 9}
 !13 = !{!"function_entry_count", i64 100000}
-!14 = !{!"branch_weights", i32 50000, i32 10000, i32 10000, i32 10000, i32 10000, i32 10000}
+!14 = !{!"branch_weights", i32 60000, i32 20000, i32 20000}
 !15 = !{!"function_entry_count", i64 1}
 !16 = !{!"branch_weights", i32 1, i32 0, i32 0, i32 0, i32 0, i32 0}
+!17 = !{!"branch_weights", i32 1, i32 99999}
+!18 = !{!"branch_weights", i32 99998, i32 1}
+!19 = !{!"branch_weights", i32 97000, i32 1000, i32 1000, i32 1000}

>From 8d3a985df083bf766d28e089ce3f7dcab2b53b00 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 10 Jan 2025 16:31:42 -0800
Subject: [PATCH 04/17] Emit jump table with section suffix

---
 llvm/include/llvm/CodeGen/AsmPrinter.h        |   8 +-
 .../CodeGen/TargetLoweringObjectFileImpl.h    |   3 +
 .../llvm/Target/TargetLoweringObjectFile.h    |   5 +
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp    | 108 +++++++++++++-----
 .../CodeGen/TargetLoweringObjectFileImpl.cpp  |  29 +++--
 llvm/lib/CodeGen/TargetPassConfig.cpp         |   2 +-
 llvm/lib/Target/TargetLoweringObjectFile.cpp  |   6 +
 llvm/test/CodeGen/X86/jump-table-partition.ll |  59 ++++++----
 8 files changed, 159 insertions(+), 61 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index c9a88d7b1c015c..4a7de22b844a65 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -453,6 +453,14 @@ class AsmPrinter : public MachineFunctionPass {
   /// function to the current output stream.
   virtual void emitJumpTableInfo();
 
+  /// Emit the jump tables whose indices (into \p MJTI's jump table list) are
+  /// given in \p JumpTableIndices. If \p JTInDiffSection is true, switch to
+  /// \p JumpTableSection first; otherwise emit in the current section inside
+  /// a jump-table data region.
+  virtual void emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
+                              MCSection *JumpTableSection, bool JTInDiffSection,
+                              const MachineJumpTableInfo &MJTI);
+
   /// Emit the specified global variable to the .s file.
   virtual void emitGlobalVariable(const GlobalVariable *GV);
 
@@ -892,10 +900,10 @@ class AsmPrinter : public MachineFunctionPass {
   // Internal Implementation Details
   //===------------------------------------------------------------------===//
 
-  void emitJumpTableEntry(const MachineJumpTableInfo *MJTI,
+  void emitJumpTableEntry(const MachineJumpTableInfo &MJTI,
                           const MachineBasicBlock *MBB, unsigned uid) const;
 
-  void emitJumpTableSizesSection(const MachineJumpTableInfo *MJTI,
+  void emitJumpTableSizesSection(const MachineJumpTableInfo &MJTI,
                                  const Function &F) const;
 
   void emitLLVMUsedList(const ConstantArray *InitList);
diff --git a/llvm/include/llvm/CodeGen/TargetLoweringObjectFileImpl.h b/llvm/include/llvm/CodeGen/TargetLoweringObjectFileImpl.h
index a2a9e5d499e527..3d48d380fcb245 100644
--- a/llvm/include/llvm/CodeGen/TargetLoweringObjectFileImpl.h
+++ b/llvm/include/llvm/CodeGen/TargetLoweringObjectFileImpl.h
@@ -74,6 +74,9 @@ class TargetLoweringObjectFileELF : public TargetLoweringObjectFile {
 
   MCSection *getSectionForJumpTable(const Function &F,
                                     const TargetMachine &TM) const override;
+  MCSection *
+  getSectionForJumpTable(const Function &F, const TargetMachine &TM,
+                         const MachineJumpTableEntry *JTE) const override;
   MCSection *getSectionForLSDA(const Function &F, const MCSymbol &FnSym,
                                const TargetMachine &TM) const override;
 
diff --git a/llvm/include/llvm/Target/TargetLoweringObjectFile.h b/llvm/include/llvm/Target/TargetLoweringObjectFile.h
index 4864ba843f4886..577adc458fcbf1 100644
--- a/llvm/include/llvm/Target/TargetLoweringObjectFile.h
+++ b/llvm/include/llvm/Target/TargetLoweringObjectFile.h
@@ -27,6 +27,7 @@ class Function;
 class GlobalObject;
 class GlobalValue;
 class MachineBasicBlock;
+struct MachineJumpTableEntry;
 class MachineModuleInfo;
 class Mangler;
 class MCContext;
@@ -132,6 +133,13 @@ class TargetLoweringObjectFile : public MCObjectFileInfo {
 
   virtual MCSection *getSectionForJumpTable(const Function &F,
                                             const TargetMachine &TM) const;
+  /// Variant of getSectionForJumpTable(F, TM) that also receives the jump
+  /// table entry being placed (\p JTE may be null), so implementations can
+  /// take per-table properties such as hotness into account.
+  virtual MCSection *
+  getSectionForJumpTable(const Function &F, const TargetMachine &TM,
+                         const MachineJumpTableEntry *JTE) const;
+
   virtual MCSection *getSectionForLSDA(const Function &, const MCSymbol &,
                                        const TargetMachine &) const {
     return LSDASection;
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index d34fe0e86c7495..208c812ac493e1 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -168,6 +168,11 @@ static cl::opt<bool> BBAddrMapSkipEmitBBEntries(
              "unnecessary for some PGOAnalysisMap features."),
     cl::Hidden, cl::init(false));
 
+static cl::opt<bool>
+    EmitStaticDataHotnessSuffix("emit-static-data-hotness-suffix", cl::Hidden,
+                                cl::init(false),
+                                cl::desc("Emit static data hotness suffix"));
+
 static cl::opt<bool> EmitJumpTableSizesSection(
     "emit-jump-table-sizes-section",
     cl::desc("Emit a section containing jump table addresses and sizes"),
@@ -2861,7 +2866,6 @@ void AsmPrinter::emitConstantPool() {
 // Print assembly representations of the jump tables used by the current
 // function.
 void AsmPrinter::emitJumpTableInfo() {
-  const DataLayout &DL = MF->getDataLayout();
   const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
   if (!MJTI) return;
   if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_Inline) return;
@@ -2876,42 +2880,90 @@ void AsmPrinter::emitJumpTableInfo() {
       MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 ||
           MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference64,
       F);
+
+  if (!EmitStaticDataHotnessSuffix) {
+    std::vector<unsigned> JumpTableIndices;
+    for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI)
+      JumpTableIndices.push_back(JTI);
+    emitJumpTables(JumpTableIndices, TLOF.getSectionForJumpTable(F, TM),
+                   JTInDiffSection, *MJTI);
+  } else {
+    // Group jump tables by hotness, keeping their original relative order, so
+    // that each group is emitted into one section. Jump tables that are not
+    // cold are grouped with the hot ones.
+    std::vector<unsigned> HotJumpTableIndices, ColdJumpTableIndices;
+    for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
+      if (JT[JTI].Hotness == MachineFunctionDataHotness::Cold)
+        ColdJumpTableIndices.push_back(JTI);
+      else
+        HotJumpTableIndices.push_back(JTI);
+    }
+
+    // Pick each group's section using an entry that belongs to that group.
+    if (!HotJumpTableIndices.empty())
+      emitJumpTables(HotJumpTableIndices,
+                     TLOF.getSectionForJumpTable(
+                         F, TM, &JT[HotJumpTableIndices.front()]),
+                     JTInDiffSection, *MJTI);
+    if (!ColdJumpTableIndices.empty())
+      emitJumpTables(ColdJumpTableIndices,
+                     TLOF.getSectionForJumpTable(
+                         F, TM, &JT[ColdJumpTableIndices.front()]),
+                     JTInDiffSection, *MJTI);
+  }
+
+  // Emit the sizes section once per function, after all jump table groups,
+  // so that each jump table is listed exactly once.
+  if (EmitJumpTableSizesSection)
+    emitJumpTableSizesSection(*MJTI, F);
+}
+
+void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
+                                MCSection *JumpTableSection,
+                                bool JTInDiffSection,
+                                const MachineJumpTableInfo &MJTI) {
+  if (JumpTableIndices.empty())
+    return;
+
+  const DataLayout &DL = MF->getDataLayout();
   if (JTInDiffSection) {
-    // Drop it in the readonly section.
-    MCSection *ReadOnlySection = TLOF.getSectionForJumpTable(F, TM);
-    OutStreamer->switchSection(ReadOnlySection);
+    OutStreamer->switchSection(JumpTableSection);
   }
 
-  emitAlignment(Align(MJTI->getEntryAlignment(DL)));
+  emitAlignment(Align(MJTI.getEntryAlignment(DL)));
 
   // Jump tables in code sections are marked with a data_region directive
   // where that's supported.
   if (!JTInDiffSection)
     OutStreamer->emitDataRegion(MCDR_DataRegionJT32);
 
-  for (unsigned JTI = 0, e = JT.size(); JTI != e; ++JTI) {
-    const std::vector<MachineBasicBlock*> &JTBBs = JT[JTI].MBBs;
+  const auto &JT = MJTI.getJumpTables();
+  for (unsigned Index = 0, e = JumpTableIndices.size(); Index != e; ++Index) {
+    const std::vector<MachineBasicBlock *> &JTBBs =
+        JT[JumpTableIndices[Index]].MBBs;
 
     // If this jump table was deleted, ignore it.
-    if (JTBBs.empty()) continue;
+    if (JTBBs.empty())
+      continue;
 
     // For the EK_LabelDifference32 entry, if using .set avoids a relocation,
     /// emit a .set directive for each unique entry.
-    if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
+    if (MJTI.getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
         MAI->doesSetDirectiveSuppressReloc()) {
-      SmallPtrSet<const MachineBasicBlock*, 16> EmittedSets;
+      SmallPtrSet<const MachineBasicBlock *, 16> EmittedSets;
       const TargetLowering *TLI = MF->getSubtarget().getTargetLowering();
-      const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF,JTI,OutContext);
+      const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(
+          MF, JumpTableIndices[Index], OutContext);
       for (const MachineBasicBlock *MBB : JTBBs) {
         if (!EmittedSets.insert(MBB).second)
           continue;
 
         // .set LJTSet, LBB32-base
         const MCExpr *LHS =
-          MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
-        OutStreamer->emitAssignment(GetJTSetSymbol(JTI, MBB->getNumber()),
-                                    MCBinaryExpr::createSub(LHS, Base,
-                                                            OutContext));
+            MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
+        OutStreamer->emitAssignment(
+            GetJTSetSymbol(JumpTableIndices[Index], MBB->getNumber()),
+            MCBinaryExpr::createSub(LHS, Base, OutContext));
       }
     }
 
@@ -2923,27 +2975,24 @@ void AsmPrinter::emitJumpTableInfo() {
       // FIXME: This doesn't have to have any specific name, just any randomly
       // named and numbered local label started with 'l' would work.  Simplify
       // GetJTISymbol.
-      OutStreamer->emitLabel(GetJTISymbol(JTI, true));
+      OutStreamer->emitLabel(GetJTISymbol(JumpTableIndices[Index], true));
 
-    MCSymbol* JTISymbol = GetJTISymbol(JTI);
+    MCSymbol *JTISymbol = GetJTISymbol(JumpTableIndices[Index]);
     OutStreamer->emitLabel(JTISymbol);
 
     // Defer MCAssembler based constant folding due to a performance issue. The
     // label differences will be evaluated at write time.
     for (const MachineBasicBlock *MBB : JTBBs)
-      emitJumpTableEntry(MJTI, MBB, JTI);
+      emitJumpTableEntry(MJTI, MBB, JumpTableIndices[Index]);
   }
 
-  if (EmitJumpTableSizesSection)
-    emitJumpTableSizesSection(MJTI, F);
-
   if (!JTInDiffSection)
     OutStreamer->emitDataRegion(MCDR_DataRegionEnd);
 }
 
-void AsmPrinter::emitJumpTableSizesSection(const MachineJumpTableInfo *MJTI,
+void AsmPrinter::emitJumpTableSizesSection(const MachineJumpTableInfo &MJTI,
                                            const Function &F) const {
-  const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
+  const std::vector<MachineJumpTableEntry> &JT = MJTI.getJumpTables();
 
   if (JT.empty())
     return;
@@ -2991,17 +3040,17 @@ void AsmPrinter::emitJumpTableSizesSection(const MachineJumpTableInfo *MJTI,
 
 /// EmitJumpTableEntry - Emit a jump table entry for the specified MBB to the
 /// current stream.
-void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI,
+void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo &MJTI,
                                     const MachineBasicBlock *MBB,
                                     unsigned UID) const {
   assert(MBB && MBB->getNumber() >= 0 && "Invalid basic block");
   const MCExpr *Value = nullptr;
-  switch (MJTI->getEntryKind()) {
+  switch (MJTI.getEntryKind()) {
   case MachineJumpTableInfo::EK_Inline:
     llvm_unreachable("Cannot emit EK_Inline jump table entry");
   case MachineJumpTableInfo::EK_Custom32:
     Value = MF->getSubtarget().getTargetLowering()->LowerCustomJumpTableEntry(
-        MJTI, MBB, UID, OutContext);
+        &MJTI, MBB, UID, OutContext);
     break;
   case MachineJumpTableInfo::EK_BlockAddress:
     // EK_BlockAddress - Each entry is a plain address of block, e.g.:
@@ -3035,7 +3084,7 @@ void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI,
     // If the .set directive avoids relocations, this is emitted as:
     //      .set L4_5_set_123, LBB123 - LJTI1_2
     //      .word L4_5_set_123
-    if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
+    if (MJTI.getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
         MAI->doesSetDirectiveSuppressReloc()) {
       Value = MCSymbolRefExpr::create(GetJTSetSymbol(UID, MBB->getNumber()),
                                       OutContext);
@@ -3051,7 +3100,7 @@ void AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI,
 
   assert(Value && "Unknown entry kind!");
 
-  unsigned EntrySize = MJTI->getEntrySize(getDataLayout());
+  unsigned EntrySize = MJTI.getEntrySize(getDataLayout());
   OutStreamer->emitValue(Value, EntrySize);
 }
 
diff --git a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
index be243c0e74e9db..e54acec98864bd 100644
--- a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/BasicBlockSectionUtils.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineModuleInfoImpls.h"
 #include "llvm/IR/Comdat.h"
@@ -642,9 +643,11 @@ static StringRef getSectionPrefixForGlobal(SectionKind Kind, bool IsLarge) {
 static SmallString<128>
 getELFSectionNameForGlobal(const GlobalObject *GO, SectionKind Kind,
                            Mangler &Mang, const TargetMachine &TM,
-                           unsigned EntrySize, bool UniqueSectionName) {
+                           unsigned EntrySize, bool UniqueSectionName,
+                           const MachineJumpTableEntry *JTE) {
   SmallString<128> Name =
       getSectionPrefixForGlobal(Kind, TM.isLargeGlobalValue(GO));
+
   if (Kind.isMergeableCString()) {
     // We also need alignment here.
     // FIXME: this is getting the alignment of the character, not the
@@ -663,7 +666,14 @@ getELFSectionNameForGlobal(const GlobalObject *GO, SectionKind Kind,
 
   bool HasPrefix = false;
   if (const auto *F = dyn_cast<Function>(GO)) {
-    if (std::optional<StringRef> Prefix = F->getSectionPrefix()) {
+    // Jump table hotness takes precedence over its enclosing function's hotness
+    // if both are available. Either kind of prefix sets `HasPrefix`.
+    if (JTE) {
+      if (JTE->Hotness == MachineFunctionDataHotness::Hot) {
+        raw_svector_ostream(Name) << ".hot";
+        HasPrefix = true;
+      }
+    } else if (std::optional<StringRef> Prefix = F->getSectionPrefix()) {
       raw_svector_ostream(Name) << '.' << *Prefix;
       HasPrefix = true;
     }
@@ -761,8 +771,8 @@ calcUniqueIDUpdateFlagsAndSize(const GlobalObject *GO, StringRef SectionName,
   // implicitly for this symbol e.g. .rodata.str1.1, then we don't need
   // to unique the section as the entry size for this symbol will be
   // compatible with implicitly created sections.
-  SmallString<128> ImplicitSectionNameStem =
-      getELFSectionNameForGlobal(GO, Kind, Mang, TM, EntrySize, false);
+  SmallString<128> ImplicitSectionNameStem = getELFSectionNameForGlobal(
+      GO, Kind, Mang, TM, EntrySize, false, /*JTE=*/nullptr);
   if (SymbolMergeable &&
       Ctx.isELFImplicitMergeableSectionNamePrefix(SectionName) &&
       SectionName.starts_with(ImplicitSectionNameStem))
@@ -862,7 +872,8 @@ MCSection *TargetLoweringObjectFileELF::getExplicitSectionGlobal(
 static MCSectionELF *selectELFSectionForGlobal(
     MCContext &Ctx, const GlobalObject *GO, SectionKind Kind, Mangler &Mang,
     const TargetMachine &TM, bool EmitUniqueSection, unsigned Flags,
-    unsigned *NextUniqueID, const MCSymbolELF *AssociatedSymbol) {
+    unsigned *NextUniqueID, const MCSymbolELF *AssociatedSymbol,
+    const MachineJumpTableEntry *JTE = nullptr) {
 
   auto [Group, IsComdat, ExtraFlags] = getGlobalObjectInfo(GO, TM);
   Flags |= ExtraFlags;
@@ -881,7 +892,7 @@ static MCSectionELF *selectELFSectionForGlobal(
     }
   }
   SmallString<128> Name = getELFSectionNameForGlobal(
-      GO, Kind, Mang, TM, EntrySize, UniqueSectionName);
+      GO, Kind, Mang, TM, EntrySize, UniqueSectionName, JTE);
 
   // Use 0 as the unique ID for execute-only text.
   if (Kind.isExecuteOnly())
@@ -955,6 +966,12 @@ MCSection *TargetLoweringObjectFileELF::getUniqueSectionForFunction(
 
 MCSection *TargetLoweringObjectFileELF::getSectionForJumpTable(
     const Function &F, const TargetMachine &TM) const {
+  return getSectionForJumpTable(F, TM, nullptr);
+}
+
+MCSection *TargetLoweringObjectFileELF::getSectionForJumpTable(
+    const Function &F, const TargetMachine &TM,
+    const MachineJumpTableEntry *JTE) const {
   // If the function can be removed, produce a unique section so that
   // the table doesn't prevent the removal.
   const Comdat *C = F.getComdat();
@@ -965,7 +982,7 @@ MCSection *TargetLoweringObjectFileELF::getSectionForJumpTable(
   return selectELFSectionForGlobal(getContext(), &F, SectionKind::getReadOnly(),
                                    getMangler(), TM, EmitUniqueSection,
                                    ELF::SHF_ALLOC, &NextUniqueID,
-                                   /* AssociatedSymbol */ nullptr);
+                                   /* AssociatedSymbol */ nullptr, JTE);
 }
 
 MCSection *TargetLoweringObjectFileELF::getSectionForLSDA(
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 6a964c0910fc61..ae8379f5b9b33d 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -1261,9 +1261,9 @@ void TargetPassConfig::addMachinePasses() {
                "performance.\n";
       }
     }
-    addPass(createMachineFunctionSplitterPass());
     if (SplitStaticData)
       addPass(createStaticDataSplitterPass());
+    addPass(createMachineFunctionSplitterPass());
   }
   // We run the BasicBlockSections pass if either we need BB sections or BB
   // address map (or both).
diff --git a/llvm/lib/Target/TargetLoweringObjectFile.cpp b/llvm/lib/Target/TargetLoweringObjectFile.cpp
index 4fe9d13d062265..f21f72ad41e98e 100644
--- a/llvm/lib/Target/TargetLoweringObjectFile.cpp
+++ b/llvm/lib/Target/TargetLoweringObjectFile.cpp
@@ -348,6 +348,12 @@ TargetLoweringObjectFile::SectionForGlobal(const GlobalObject *GO,
 
 MCSection *TargetLoweringObjectFile::getSectionForJumpTable(
     const Function &F, const TargetMachine &TM) const {
+  return getSectionForJumpTable(F, TM, nullptr);
+}
+
+MCSection *TargetLoweringObjectFile::getSectionForJumpTable(
+    const Function &F, const TargetMachine &TM,
+    const MachineJumpTableEntry *JTE) const {
   Align Alignment(1);
   return getSectionForConstant(F.getDataLayout(),
                                SectionKind::getReadOnly(), /*C=*/nullptr,
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 3b5af2c098553f..12f9b3d42afd54 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -6,6 +6,8 @@
 ; RUN: llc -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
 ; RUN: llc --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
 
+; RUN: llc -enable-split-machine-functions -split-static-data -emit-static-data-hotness-suffix=true -function-sections -min-jump-table-entries=2 -disable-block-placement %s -o - 2>&1 | FileCheck %s --check-prefix=SECTION
+
 ; Tests stat messages are expected.
 ; TODO: Update test to verify section suffixes when target-lowering and assembler changes are implemented.
 ; TODO: Also run static-data-splitter pass with -static-data-default-hotness=cold and check data section suffix.
@@ -14,6 +16,13 @@
 ; STAT-DAG: 3 static-data-splitter - Number of hot jump tables seen
 ; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
 
+; SECTION: .section .rodata.hot.foo,"a", at progbits
+; SECTION: .LJTI0_0:
+; SECTION: .LJTI0_2:
+; SECTION: .section .rodata.foo,"a", at progbits
+; SECTION: .LJTI0_1:
+; SECTION: .LJTI0_3:
+
 ; @foo has four jump tables, jt0, jt1, jt2 and jt3 in the input basic block
 ; order; jt0 and jt2 are hot, and jt1 and jt3 are cold.
 ;
@@ -59,29 +68,7 @@ jt0.default:
 
 jt0.epilog:
   %zero = icmp eq i32 %num, 0
-  br i1 %zero, label %cold, label %hot, !prof !17
-
-cold:
-  %c1 = call i32 @compute(i32 %num)
-  switch i32 %c1, label %jt1.default [
-    i32 1, label %jt1.bb1
-    i32 2, label %jt1.bb2
-  ], !prof !14
-
-jt1.bb1:
-  call i32 @puts(ptr @case1)
-  br label %jt1.epilog
-
-jt1.bb2:
-  call i32 @puts(ptr @case2)
-  br label %jt1.epilog
-
-jt1.default:
-  call i32 @puts(ptr @default)
-  br label %jt1.epilog
-
-jt1.epilog:
-  br label %return
+  br i1 %zero, label %hot, label %cold, !prof !17
 
 hot:
  %c2 = call i32 @transform(i32 %num)
@@ -106,6 +93,28 @@ jt2.epilog:
   %c2cmp = icmp ne i32 %c2, 0
   br i1 %c2cmp, label %return, label %jt3.prologue, !prof !18
 
+cold:
+  %c1 = call i32 @compute(i32 %num)
+  switch i32 %c1, label %jt1.default [
+    i32 1, label %jt1.bb1
+    i32 2, label %jt1.bb2
+  ], !prof !14
+
+jt1.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt1.epilog
+
+jt1.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt1.epilog
+
+jt1.default:
+  call i32 @puts(ptr @default)
+  br label %jt1.epilog
+
+jt1.epilog:
+  br label %return
+
 jt3.prologue:
   %c3 = call i32 @cleanup(i32 %num)
   switch i32 %c3, label %jt3.default [
@@ -166,7 +175,7 @@ loop.exit:
   %inc = add i32 %i.04, 1
   %lsr.iv.next = add i32 %lsr.iv, -1
   %exitcond.not = icmp eq i32 %lsr.iv.next, 0
-  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !prof !17
+  br i1 %exitcond.not, label %for.body, label %for.cond.cleanup, !prof !17
 }
 
 define void @func_without_entry_count(i32 %num) {
@@ -217,6 +226,6 @@ declare i32 @cleanup(i32)
 !14 = !{!"branch_weights", i32 60000, i32 20000, i32 20000}
 !15 = !{!"function_entry_count", i64 1}
 !16 = !{!"branch_weights", i32 1, i32 0, i32 0, i32 0, i32 0, i32 0}
-!17 = !{!"branch_weights", i32 1, i32 99999}
+!17 = !{!"branch_weights", i32 99999, i32 1}
 !18 = !{!"branch_weights", i32 99998, i32 1}
 !19 = !{!"branch_weights", i32 97000, i32 1000, i32 1000, i32 1000}

>From 1bacc51a5ae03ca9bda60dda4a63de9944d62950 Mon Sep 17 00:00:00 2001
From: Mingming Liu <mingmingl at google.com>
Date: Fri, 10 Jan 2025 17:11:29 -0800
Subject: [PATCH 05/17] Apply suggestions from code review

Co-authored-by: Ellis Hoag <ellis.sparky.hoag at gmail.com>
---
 llvm/lib/CodeGen/MachineFunction.cpp    | 2 +-
 llvm/lib/CodeGen/StaticDataSplitter.cpp | 5 +----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp
index b5a89f3bcf42f1..117bd988a40af2 100644
--- a/llvm/lib/CodeGen/MachineFunction.cpp
+++ b/llvm/lib/CodeGen/MachineFunction.cpp
@@ -1347,7 +1347,7 @@ unsigned MachineJumpTableInfo::createJumpTableIndex(
 void MachineJumpTableInfo::updateJumpTableHotness(size_t JTI,
                                                   DataHotness Hotness) {
   assert(JTI < JumpTables.size() && "Invalid JTI!");
-  // Note record the largest hotness is important for mergable data (constant
+  // Note: recording the largest hotness is important for mergable data (constant
   // pools). Even if jump table instances are not merged, record the largest
   // value seen fwiw.
   JumpTables[JTI].Hotness = std::max(JumpTables[JTI].Hotness, Hotness);
diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
index 482a61027cf985..55c4b20b76102c 100644
--- a/llvm/lib/CodeGen/StaticDataSplitter.cpp
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -97,10 +97,7 @@ bool StaticDataSplitter::splitJumpTablesWithProfiles(
     if (JTI == -1)
       continue;
 
-    bool AllBlocksCold = true;
-
-    if (!PSI->isColdBlock(&MBB, MBFI))
-      AllBlocksCold = false;
+    bool AllBlocksCold = PSI->isColdBlock(&MBB, MBFI);
 
     for (const MachineBasicBlock *MBB : MJTI.getJumpTables()[JTI].MBBs)
       if (!PSI->isColdBlock(MBB, MBFI))

>From 8a85d1aa36e22619231c5e079f369c499dd67f6a Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 10 Jan 2025 17:17:14 -0800
Subject: [PATCH 06/17] resolve review feedback

---
 llvm/include/llvm/CodeGen/MachineFunction.h   |   2 +-
 .../llvm/CodeGen/MachineJumpTableInfo.h       |   9 +-
 llvm/include/llvm/CodeGen/Passes.h            |   2 +-
 llvm/lib/CodeGen/MachineFunction.cpp          |  14 +-
 llvm/lib/CodeGen/StaticDataSplitter.cpp       |  88 +++++---
 llvm/test/CodeGen/X86/jump-table-partition.ll | 212 ++++++++++--------
 6 files changed, 184 insertions(+), 143 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h
index c0f983d1c6787b..dcdbcaec168d22 100644
--- a/llvm/include/llvm/CodeGen/MachineFunction.h
+++ b/llvm/include/llvm/CodeGen/MachineFunction.h
@@ -91,7 +91,7 @@ template <> struct ilist_callback_traits<MachineBasicBlock> {
 // The hotness of static data tracked by a MachineFunction and not represented
 // as a global object in the module IR / MIR. Typical examples are
 // MachineJumpTableInfo and MachineConstantPool.
-enum class DataHotness {
+enum class MachineFunctionDataHotness {
   Unknown,
   Cold,
   Hot,
diff --git a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
index cc1f54a81b9bb4..e3675d6489b350 100644
--- a/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineJumpTableInfo.h
@@ -28,7 +28,7 @@ namespace llvm {
 class MachineBasicBlock;
 class DataLayout;
 class raw_ostream;
-enum class DataHotness;
+enum class MachineFunctionDataHotness;
 
 /// MachineJumpTableEntry - One jump table in the jump table info.
 ///
@@ -36,7 +36,7 @@ struct MachineJumpTableEntry {
   /// MBBs - The vector of basic blocks from which to create the jump table.
   std::vector<MachineBasicBlock*> MBBs;
 
-  DataHotness Hotness;
+  MachineFunctionDataHotness Hotness;
 
   explicit MachineJumpTableEntry(const std::vector<MachineBasicBlock *> &M);
 };
@@ -109,7 +109,10 @@ class MachineJumpTableInfo {
     return JumpTables;
   }
 
-  void updateJumpTableHotness(size_t JTI, DataHotness Hotness);
+  // Update machine jump table entry's hotness. Return true if the hotness is
+  // updated.
+  bool updateJumpTableEntryHotness(size_t JTI,
+                                   MachineFunctionDataHotness Hotness);
 
   /// RemoveJumpTable - Mark the specific index as being dead.  This will
   /// prevent it from being emitted.
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 16423d03ff7018..b5d2a7e6bf035b 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -71,7 +71,7 @@ namespace llvm {
   /// using profile information.
   MachineFunctionPass *createMachineFunctionSplitterPass();
 
-  /// createStaticDataSplitterPass - This pass partions static data sections
+  /// createStaticDataSplitterPass - This pass partitions a static data section
   /// into a hot and cold section using profile information.
   MachineFunctionPass *createStaticDataSplitterPass();
 
diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp
index 117bd988a40af2..d09e93d79aae6c 100644
--- a/llvm/lib/CodeGen/MachineFunction.cpp
+++ b/llvm/lib/CodeGen/MachineFunction.cpp
@@ -1293,7 +1293,7 @@ const unsigned MachineFunction::DebugOperandMemNumber = 1000000;
 
 MachineJumpTableEntry::MachineJumpTableEntry(
     const std::vector<MachineBasicBlock *> &MBBs)
-    : MBBs(MBBs), Hotness(DataHotness::Unknown) {}
+    : MBBs(MBBs), Hotness(MachineFunctionDataHotness::Unknown) {}
 
 /// Return the size of each entry in the jump table.
 unsigned MachineJumpTableInfo::getEntrySize(const DataLayout &TD) const {
@@ -1344,13 +1344,17 @@ unsigned MachineJumpTableInfo::createJumpTableIndex(
   return JumpTables.size()-1;
 }
 
-void MachineJumpTableInfo::updateJumpTableHotness(size_t JTI,
-                                                  DataHotness Hotness) {
+bool MachineJumpTableInfo::updateJumpTableEntryHotness(
+    size_t JTI, MachineFunctionDataHotness Hotness) {
   assert(JTI < JumpTables.size() && "Invalid JTI!");
-  // Note: recording the largest hotness is important for mergable data (constant
+  // Note record the largest hotness is important for mergable data (constant
   // pools). Even if jump table instances are not merged, record the largest
   // value seen fwiw.
-  JumpTables[JTI].Hotness = std::max(JumpTables[JTI].Hotness, Hotness);
+  if (Hotness <= JumpTables[JTI].Hotness)
+    return false;
+
+  JumpTables[JTI].Hotness = Hotness;
+  return true;
 }
 
 /// If Old is the target of any jump tables, update the jump tables to branch
diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
index 55c4b20b76102c..82673e851a6817 100644
--- a/llvm/lib/CodeGen/StaticDataSplitter.cpp
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -6,13 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass uses profile information to partition static data sections into
-// hot and cold ones. It begins to split jump tables based on profile, and
-// subsequent patches will handle constant pools and other module internal data.
+// The pass uses branch profile data to assign hotness based section qualifiers
+// for the following types of static data:
+// - Jump tables
+// - Constant pools (TODO)
+// - Other module-internal data (TODO)
 //
 // For the original RFC of this pass please see
-// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744.
+// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/CodeGen/MBFIWrapper.h"
@@ -35,8 +38,16 @@ using namespace llvm;
 STATISTIC(NumHotJumpTables, "Number of hot jump tables seen");
 STATISTIC(NumColdJumpTables, "Number of cold jump tables seen");
 STATISTIC(NumUnknownJumpTables,
-          "Number of jump tables with unknown hotness. Such jump tables will "
-          "be placed in the hot-suffixed section by default.");
+          "Number of jump tables with unknown hotness. Option "
+          "-static-data-default-hotness specifies the hotness.");
+
+static cl::opt<MachineFunctionDataHotness> StaticDataDefaultHotness(
+    "static-data-default-hotness", cl::Hidden,
+    cl::desc("This option specifies the hotness of static data when profile "
+             "information is unavailable"),
+    cl::init(MachineFunctionDataHotness::Hot),
+    cl::values(clEnumValN(MachineFunctionDataHotness::Hot, "hot", "Hot"),
+               clEnumValN(MachineFunctionDataHotness::Cold, "cold", "Cold")));
 
 class StaticDataSplitter : public MachineFunctionPass {
   const MachineBranchProbabilityInfo *MBPI = nullptr;
@@ -74,20 +85,13 @@ bool StaticDataSplitter::runOnMachineFunction(MachineFunction &MF) {
   MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
   PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
 
-  // Split jump tables based on profile information. Subsequent patches will
-  // handle other data types like constant pools, module-internal data, etc.
   return splitJumpTables(MF);
 }
 
 bool StaticDataSplitter::splitJumpTablesWithProfiles(
     MachineFunction &MF, MachineJumpTableInfo &MJTI) {
   int NumChangedJumpTables = 0;
-  // Regard a jump table as hot by default. If the source and all of destination
-  // blocks are cold, regard the jump table as cold. While a destination block
-  // does not read a jump table (unless it's also a source block), a hot
-  // destination heuristically makes its jump table hot to accommodate for
-  // potential profile data skews (from sampled profiles, for example).
-  DataHotness Hotness = DataHotness::Hot;
+
   for (const auto &MBB : MF) {
     // IMPORTANT, `getJumpTableIndex` is a thin wrapper around per-target
     // interface `TargetInstrInfo::getjumpTableIndex`, and only X86 implements
@@ -97,21 +101,14 @@ bool StaticDataSplitter::splitJumpTablesWithProfiles(
     if (JTI == -1)
       continue;
 
-    bool AllBlocksCold = PSI->isColdBlock(&MBB, MBFI);
+    auto Hotness = MachineFunctionDataHotness::Hot;
 
-    for (const MachineBasicBlock *MBB : MJTI.getJumpTables()[JTI].MBBs)
-      if (!PSI->isColdBlock(MBB, MBFI))
-        AllBlocksCold = false;
+    // Hotness is based on source basic block hotness.
+    if (PSI->isColdBlock(&MBB, MBFI))
+      Hotness = MachineFunctionDataHotness::Cold;
 
-    if (AllBlocksCold) {
-      Hotness = DataHotness::Cold;
-      ++NumColdJumpTables;
-    } else {
-      ++NumHotJumpTables;
-    }
-
-    MF.getJumpTableInfo()->updateJumpTableHotness(JTI, Hotness);
-    ++NumChangedJumpTables;
+    if (MF.getJumpTableInfo()->updateJumpTableEntryHotness(JTI, Hotness))
+      ++NumChangedJumpTables;
   }
   return NumChangedJumpTables > 0;
 }
@@ -121,18 +118,41 @@ bool StaticDataSplitter::splitJumpTables(MachineFunction &MF) {
   if (!MJTI || MJTI->getJumpTables().empty())
     return false;
 
+  const bool ProfileAvailable = PSI && PSI->hasProfileSummary() && MBFI &&
+                                MF.getFunction().hasProfileData();
+  auto statOnExit = llvm::make_scope_exit([&] {
+    if (!AreStatisticsEnabled())
+      return;
+
+    if (!ProfileAvailable) {
+      NumUnknownJumpTables += MJTI->getJumpTables().size();
+      return;
+    }
+
+    for (size_t JTI = 0; JTI < MJTI->getJumpTables().size(); JTI++) {
+      auto Hotness = MJTI->getJumpTables()[JTI].Hotness;
+      if (Hotness == MachineFunctionDataHotness::Hot)
+        NumHotJumpTables++;
+      else {
+        assert(Hotness == MachineFunctionDataHotness::Cold &&
+               "A jump table is either hot or cold when profile information is "
+               "available.");
+        NumColdJumpTables++;
+      }
+    }
+  });
+
   // Place jump tables according to block hotness if function has profile data.
-  if (PSI && PSI->hasProfileSummary() && MBFI &&
-      MF.getFunction().hasProfileData())
+  if (ProfileAvailable)
     return splitJumpTablesWithProfiles(MF, *MJTI);
 
-  // Conservatively place all jump tables in the hot-suffixed section if profile
-  // information for the function is not available, or the target doesn't
-  // implement `TargetInstrInfo::getJumpTableIndex` yet.
+  // If function profile is unavailable (e.g., module not instrumented, or new
+  // code paths lacking samples), -static-data-default-hotness specifies the
+  // hotness.
   for (size_t JTI = 0; JTI < MJTI->getJumpTables().size(); JTI++)
-    MF.getJumpTableInfo()->updateJumpTableHotness(JTI, DataHotness::Hot);
+    MF.getJumpTableInfo()->updateJumpTableEntryHotness(
+        JTI, StaticDataDefaultHotness);
 
-  NumUnknownJumpTables += MJTI->getJumpTables().size();
   return true;
 }
 
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 3a8f1395f6b283..1b717451a86e28 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -1,146 +1,160 @@
-;; -stats requires asserts
+; -stats requires asserts
 ; requires: asserts
 
-; RUN: llc -stop-after=block-placement %s -o - | llc --run-pass=static-data-splitter -stats -x mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
+; Stop after 'finalize-isel' for simpler MIR, and lower the minimum number of
+; jump table entries so 'switch' needs fewer cases to generate a jump table.
+; RUN: llc -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
+; RUN: llc --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
 
-; `func_with_hot_jumptable` contains a hot jump table and `func_with_cold_jumptable` contains a cold one. 
-; `func_without_entry_count` simulates the functions without profile information (e.g., not instrumented or not profiled),
-; it's jump table hotness is unknown and regarded as hot conservatively.
-;
 ; Tests stat messages are expected.
 ; TODO: Update test to verify section suffixes when target-lowering and assembler changes are implemented.
-;
-; STAT-DAG: 1 static-data-splitter - Number of cold jump tables seen
-; STAT-DAG: 1 static-data-splitter - Number of hot jump tables seen
+; TODO: Also run static-data-splitter pass with -static-data-default-hotness=cold and check data section suffix.
+ 
+; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
+; STAT-DAG: 2 static-data-splitter - Number of hot jump tables seen
 ; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
 
+; In function @foo, the 2 switch instructions to jt0.* and jt2.* get lowered to hot jump tables,
+; and the 2 switch instructions to jt1.* and jt3.* get lowered to cold jump tables.
+
+; @func_without_profile doesn't have profiles. It's jump table hotness is unknown.
+
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
 
- at .str.2 = private constant [7 x i8] c"case 3\00"
- at .str.3 = private constant [7 x i8] c"case 4\00"
- at .str.4 = private constant [7 x i8] c"case 5\00"
- at str.9 = private constant [7 x i8] c"case 2\00"
- at str.10 = private constant [7 x i8] c"case 1\00"
- at str.11 = private constant [8 x i8] c"default\00"
+ at str.9 = private constant [7 x i8] c".str.9\00"
+ at str.10 = private constant [8 x i8] c".str.10\00"
+ at str.11 = private constant [8 x i8] c".str.11\00"
 
-define i32 @func_with_hot_jumptable(i32 %num) !prof !13 {
+ at case2 = private constant [7 x i8] c"case 2\00"
+ at case1 = private constant [7 x i8] c"case 1\00"
+ at default = private constant [8 x i8] c"default\00"
+ at jt3 = private constant [4 x i8] c"jt3\00"
+
+define i32 @foo(i32 %num) !prof !13 {
 entry:
-  switch i32 %num, label %sw.default [
-    i32 1, label %sw.bb
-    i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
+  %mod3 = sdiv i32 %num, 3
+  switch i32 %mod3, label %jt0.default [
+    i32 1, label %jt0.bb1
+    i32 2, label %jt0.bb2
   ], !prof !14
 
-sw.bb:                                            ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.10)
-  br label %sw.epilog
+jt0.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt0.epilog
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
+jt0.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt0.epilog
 
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
+jt0.default:
+  call i32 @puts(ptr @default)
+  br label %jt0.epilog
 
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
+jt0.epilog:
+  %zero = icmp eq i32 %num, 0
+  br i1 %zero, label %hot, label %cold, !prof !15
 
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
-  br label %sw.epilog
+hot:
+ %c2 = call i32 @transform(i32 %num)
+  switch i32 %c2, label %jt2.default [
+    i32 1, label %jt2.bb1
+    i32 2, label %jt2.bb2
+  ], !prof !14
 
-sw.default:                                       ; preds = %entry
-  %puts12 = tail call i32 @puts(ptr @str.11)
-  br label %sw.epilog
+jt2.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt1.epilog
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
-  %div = sdiv i32 %num, 3
-  ret i32 %div
-}
+jt2.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt1.epilog
 
-define void @func_with_cold_jumptable(i32 %num) !prof !15 {
-entry:
-  switch i32 %num, label %sw.default [
-    i32 1, label %sw.bb
-    i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
-  ], !prof !16
+jt2.default:
+  call i32 @puts(ptr @default)
+  br label %jt2.epilog
 
-sw.bb:                                            ; preds = %entry
-  %puts10 = tail call i32 @puts(ptr @str.10)
-  br label %sw.epilog
+jt2.epilog:
+  %c2cmp = icmp ne i32 %c2, 0
+  br i1 %c2cmp, label %return, label %jt3.prologue, !prof !16
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
+cold:
+  %c1 = call i32 @compute(i32 %num)
+  switch i32 %c1, label %jt1.default [
+    i32 1, label %jt1.bb1
+    i32 2, label %jt1.bb2
+  ], !prof !14
 
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
+jt1.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt1.epilog
 
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
+jt1.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt1.epilog
 
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
-  br label %sw.epilog
+jt1.default:
+  call i32 @puts(ptr @default)
+  br label %jt1.epilog
 
-sw.default:                                       ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.11)
-  br label %sw.epilog
+jt1.epilog:
+  br label %return
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
-  ret void
+jt3.prologue:
+  %c3 = call i32 @cleanup(i32 %num)
+  switch i32 %c3, label %jt3.default [
+    i32 1, label %jt3.bb1
+    i32 2, label %jt3.bb2
+  ], !prof !14
+
+jt3.bb1:
+  call i32 @puts(ptr @case1)
+  br label %jt3.epilog
+
+jt3.bb2:
+  call i32 @puts(ptr @case2)
+  br label %jt3.epilog
+
+jt3.default:
+  call i32 @puts(ptr @default)
+  br label %jt3.epilog
+
+jt3.epilog:
+  call i32 @puts(ptr @jt3)
+  br label %return
+
+return:
+  ret i32 %mod3
 }
 
-define void @func_without_entry_count(i32 %num) {
+define void @func_without_profile(i32 %num) {
 entry:
   switch i32 %num, label %sw.default [
     i32 1, label %sw.bb
     i32 2, label %sw.bb1
-    i32 3, label %sw.bb3
-    i32 4, label %sw.bb5
-    i32 5, label %sw.bb7
   ]
 
-sw.bb:                                            ; preds = %entry
-  %puts10 = tail call i32 @puts(ptr @str.10)
+sw.bb:
+  call i32 @puts(ptr @str.10)
   br label %sw.epilog
 
-sw.bb1:                                           ; preds = %entry
-  %puts = tail call i32 @puts(ptr @str.9)
-  br label %sw.epilog
-
-sw.bb3:                                           ; preds = %entry
-  %call4 = tail call i32 (ptr, ...) @printf(ptr @.str.2)
-  br label %sw.bb5
-
-sw.bb5:                                           ; preds = %entry, %sw.bb3
-  %call6 = tail call i32 (ptr, ...) @printf(ptr @.str.3)
-  br label %sw.bb7
-
-sw.bb7:                                           ; preds = %entry, %sw.bb5
-  %call8 = tail call i32 (ptr, ...) @printf(ptr @.str.4)
+sw.bb1: 
+  call i32 @puts(ptr @str.9)
   br label %sw.epilog
 
-sw.default:                                       ; preds = %entry
-  %puts11 = tail call i32 @puts(ptr @str.11)
+sw.default:
+  call i32 @puts(ptr @str.11)
   br label %sw.epilog
 
-sw.epilog:                                        ; preds = %sw.default, %sw.bb7, %sw.bb1, %sw.bb
+sw.epilog:                                       
   ret void
 }
 
 declare i32 @puts(ptr)
 declare i32 @printf(ptr, ...)
+declare i32 @compute(i32)
+declare i32 @transform(i32)
+declare i32 @cleanup(i32)
 
 !llvm.module.flags = !{!0}
 
@@ -158,6 +172,6 @@ declare i32 @printf(ptr, ...)
 !11 = !{i32 990000, i64 10000, i32 7}
 !12 = !{i32 999999, i64 1, i32 9}
 !13 = !{!"function_entry_count", i64 100000}
-!14 = !{!"branch_weights", i32 50000, i32 10000, i32 10000, i32 10000, i32 10000, i32 10000}
-!15 = !{!"function_entry_count", i64 1}
-!16 = !{!"branch_weights", i32 1, i32 0, i32 0, i32 0, i32 0, i32 0}
+!14 = !{!"branch_weights", i32 60000, i32 20000, i32 20000}
+!15 = !{!"branch_weights", i32 1, i32 99999}
+!16 = !{!"branch_weights", i32 99998, i32 1}

>From e54dacbbbf94538674868b7a8ae1a86dccac44fb Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Sun, 12 Jan 2025 22:20:27 -0800
Subject: [PATCH 07/17] Discover jump table by calling MachineOperand::isJTI()

---
 llvm/include/llvm/CodeGen/MachineBasicBlock.h |  6 ---
 llvm/lib/CodeGen/MachineBasicBlock.cpp        |  4 --
 llvm/lib/CodeGen/StaticDataSplitter.cpp       | 40 +++++++++++--------
 3 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
index 9ed356c2ab2331..7fe33c3913f2dd 100644
--- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h
+++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
@@ -997,12 +997,6 @@ class MachineBasicBlock
   /// no changes occurred in the meantime.
   bool canSplitCriticalEdge(const MachineBasicBlock *Succ) const;
 
-  /// Return an index for MachineJumpTableInfo if \p this basic block ends with
-  /// an indirect jump using a jump table, otherwise -1.
-  /// This function is a thin wrapper and forward calls to the per-target method
-  /// `TargetInstrInfo::getjumpTableIndex`.
-  int getJumpTableIndex() const;
-
   void pop_front() { Insts.pop_front(); }
   void pop_back() { Insts.pop_back(); }
   void push_back(MachineInstr *MI) { Insts.push_back(MI); }
diff --git a/llvm/lib/CodeGen/MachineBasicBlock.cpp b/llvm/lib/CodeGen/MachineBasicBlock.cpp
index af298c5994fbb6..5ac6472a01e9fc 100644
--- a/llvm/lib/CodeGen/MachineBasicBlock.cpp
+++ b/llvm/lib/CodeGen/MachineBasicBlock.cpp
@@ -1426,10 +1426,6 @@ bool MachineBasicBlock::canSplitCriticalEdge(
   return true;
 }
 
-int MachineBasicBlock::getJumpTableIndex() const {
-  return findJumpTableIndex(*this);
-}
-
 /// Prepare MI to be removed from its bundle. This fixes bundle flags on MI's
 /// neighboring instructions so the bundle won't be broken by removing MI.
 static void unbundleSingleMI(MachineInstr *MI) {
diff --git a/llvm/lib/CodeGen/StaticDataSplitter.cpp b/llvm/lib/CodeGen/StaticDataSplitter.cpp
index 82673e851a6817..2277a15ea6e3e3 100644
--- a/llvm/lib/CodeGen/StaticDataSplitter.cpp
+++ b/llvm/lib/CodeGen/StaticDataSplitter.cpp
@@ -92,23 +92,31 @@ bool StaticDataSplitter::splitJumpTablesWithProfiles(
     MachineFunction &MF, MachineJumpTableInfo &MJTI) {
   int NumChangedJumpTables = 0;
 
+  // Jump table could be used by either terminating instructions or
+  // non-terminating ones, so we walk all instructions and use
+  // `MachineOperand::isJTI()` to identify jump table operands.
+  // Similarly, `MachineOperand::isCPI()` can identify constant pool usages
+  // in the same loop.
   for (const auto &MBB : MF) {
-    // IMPORTANT, `getJumpTableIndex` is a thin wrapper around per-target
-    // interface `TargetInstrInfo::getjumpTableIndex`, and only X86 implements
-    // it so far.
-    const int JTI = MBB.getJumpTableIndex();
-    // This is not a source block of jump table.
-    if (JTI == -1)
-      continue;
-
-    auto Hotness = MachineFunctionDataHotness::Hot;
-
-    // Hotness is based on source basic block hotness.
-    if (PSI->isColdBlock(&MBB, MBFI))
-      Hotness = MachineFunctionDataHotness::Cold;
-
-    if (MF.getJumpTableInfo()->updateJumpTableEntryHotness(JTI, Hotness))
-      ++NumChangedJumpTables;
+    for (const MachineInstr &I : MBB) {
+      for (const MachineOperand &Op : I.operands()) {
+        if (!Op.isJTI())
+          continue;
+        const int JTI = Op.getIndex();
+        // This is not a source block of jump table.
+        if (JTI == -1)
+          continue;
+
+        auto Hotness = MachineFunctionDataHotness::Hot;
+
+        // Hotness is based on source basic block hotness.
+        if (PSI->isColdBlock(&MBB, MBFI))
+          Hotness = MachineFunctionDataHotness::Cold;
+
+        if (MF.getJumpTableInfo()->updateJumpTableEntryHotness(JTI, Hotness))
+          ++NumChangedJumpTables;
+      }
+    }
   }
   return NumChangedJumpTables > 0;
 }

>From 27ef86dec4a69cbc89f9856e317fbdf65bf9f8d6 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 14 Jan 2025 13:30:19 -0800
Subject: [PATCH 08/17] Introduce a TargetMachine option, and the command line
 flag option to set it

---
 llvm/include/llvm/CodeGen/CommandFlags.h      |  2 +
 llvm/include/llvm/Target/TargetMachine.h      |  4 +
 llvm/include/llvm/Target/TargetOptions.h      |  3 +
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp    |  7 +-
 llvm/lib/CodeGen/CommandFlags.cpp             |  8 ++
 .../CodeGen/TargetLoweringObjectFileImpl.cpp  |  2 +-
 llvm/lib/CodeGen/TargetPassConfig.cpp         |  2 +-
 llvm/test/CodeGen/X86/jump-table-partition.ll | 74 +++++++------------
 8 files changed, 46 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CommandFlags.h b/llvm/include/llvm/CodeGen/CommandFlags.h
index d5448d781363d4..000aed782a8057 100644
--- a/llvm/include/llvm/CodeGen/CommandFlags.h
+++ b/llvm/include/llvm/CodeGen/CommandFlags.h
@@ -136,6 +136,8 @@ bool getEmitCallSiteInfo();
 
 bool getEnableMachineFunctionSplitter();
 
+bool getEnableStaticDataPartitioning();
+
 bool getEnableDebugEntryValues();
 
 bool getValueTrackingVariableLocations();
diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h
index 9bdb110bd36839..4a54c706c0cb6a 100644
--- a/llvm/include/llvm/Target/TargetMachine.h
+++ b/llvm/include/llvm/Target/TargetMachine.h
@@ -305,6 +305,10 @@ class TargetMachine {
     return Options.FunctionSections;
   }
 
+  bool getEnableStaticDataPartitioning() const {
+    return Options.EnableStaticDataPartitioning;
+  }
+
   /// Return true if visibility attribute should not be emitted in XCOFF,
   /// corresponding to -mignore-xcoff-visibility.
   bool getIgnoreXCOFFVisibility() const {
diff --git a/llvm/include/llvm/Target/TargetOptions.h b/llvm/include/llvm/Target/TargetOptions.h
index 88f253805ca99c..1ddee265effa73 100644
--- a/llvm/include/llvm/Target/TargetOptions.h
+++ b/llvm/include/llvm/Target/TargetOptions.h
@@ -312,6 +312,9 @@ namespace llvm {
     /// Enables the MachineFunctionSplitter pass.
     unsigned EnableMachineFunctionSplitter : 1;
 
+    /// Enables the StaticDataSplitter pass.
+    unsigned EnableStaticDataPartitioning : 1;
+
     /// Set if the target supports default outlining behaviour.
     unsigned SupportsDefaultOutlining : 1;
 
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 208c812ac493e1..c4629561a4267e 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -168,11 +168,6 @@ static cl::opt<bool> BBAddrMapSkipEmitBBEntries(
              "unnecessary for some PGOAnalysisMap features."),
     cl::Hidden, cl::init(false));
 
-static cl::opt<bool>
-    EmitStaticDataHotnessSuffix("emit-static-data-hotness-suffix", cl::Hidden,
-                                cl::init(false), cl::ZeroOrMore,
-                                cl::desc("Emit static data hotness suffix"));
-
 static cl::opt<bool> EmitJumpTableSizesSection(
     "emit-jump-table-sizes-section",
     cl::desc("Emit a section containing jump table addresses and sizes"),
@@ -2882,7 +2877,7 @@ void AsmPrinter::emitJumpTableInfo() {
       F);
 
   std::vector<unsigned> JumpTableIndices;
-  if (!EmitStaticDataHotnessSuffix) {
+  if (!TM.Options.EnableStaticDataPartitioning) {
     for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI)
       JumpTableIndices.push_back(JTI);
     emitJumpTables(JumpTableIndices, TLOF.getSectionForJumpTable(F, TM),
diff --git a/llvm/lib/CodeGen/CommandFlags.cpp b/llvm/lib/CodeGen/CommandFlags.cpp
index d180cfcea658c2..023656cde0089e 100644
--- a/llvm/lib/CodeGen/CommandFlags.cpp
+++ b/llvm/lib/CodeGen/CommandFlags.cpp
@@ -103,6 +103,7 @@ CGOPT(bool, EnableStackSizeSection)
 CGOPT(bool, EnableAddrsig)
 CGOPT(bool, EmitCallSiteInfo)
 CGOPT(bool, EnableMachineFunctionSplitter)
+CGOPT(bool, EnableStaticDataPartitioning)
 CGOPT(bool, EnableDebugEntryValues)
 CGOPT(bool, ForceDwarfFrameSection)
 CGOPT(bool, XRayFunctionIndex)
@@ -480,6 +481,12 @@ codegen::RegisterCodeGenFlags::RegisterCodeGenFlags() {
       cl::init(false));
   CGBINDOPT(EnableMachineFunctionSplitter);
 
+  static cl::opt<bool> EnableStaticDataPartitioning(
+      "partition-static-data-sections",
+      cl::desc("Partition data sections using profile information."),
+      cl::init(false));
+  CGBINDOPT(EnableStaticDataPartitioning);
+
   static cl::opt<bool> ForceDwarfFrameSection(
       "force-dwarf-frame-section",
       cl::desc("Always emit a debug frame section."), cl::init(false));
@@ -586,6 +593,7 @@ codegen::InitTargetOptionsFromCodeGenFlags(const Triple &TheTriple) {
   Options.ExceptionModel = getExceptionModel();
   Options.EmitStackSizeSection = getEnableStackSizeSection();
   Options.EnableMachineFunctionSplitter = getEnableMachineFunctionSplitter();
+  Options.EnableStaticDataPartitioning = getEnableStaticDataPartitioning();
   Options.EmitAddrsig = getEnableAddrsig();
   Options.EmitCallSiteInfo = getEmitCallSiteInfo();
   Options.EnableDebugEntryValues = getEnableDebugEntryValues();
diff --git a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
index e54acec98864bd..5e19df037290a4 100644
--- a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
@@ -974,7 +974,7 @@ MCSection *TargetLoweringObjectFileELF::getSectionForJumpTable(
   // the table doesn't prevent the removal.
   const Comdat *C = F.getComdat();
   bool EmitUniqueSection = TM.getFunctionSections() || C;
-  if (!EmitUniqueSection)
+  if (!EmitUniqueSection && !TM.getEnableStaticDataPartitioning())
     return ReadOnlySection;
 
   return selectELFSectionForGlobal(getContext(), &F, SectionKind::getReadOnly(),
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index ae8379f5b9b33d..3f7b2a1b52ac31 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -1261,7 +1261,7 @@ void TargetPassConfig::addMachinePasses() {
                "performance.\n";
       }
     }
-    if (SplitStaticData)
+    if (SplitStaticData || TM->Options.EnableStaticDataPartitioning)
       addPass(createStaticDataSplitterPass());
     addPass(createMachineFunctionSplitterPass());
   }
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 12f9b3d42afd54..8c542beef57bb4 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -6,22 +6,36 @@
 ; RUN: llc -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
 ; RUN: llc --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
 
-; RUN: llc -enable-split-machine-functions -split-static-data -emit-static-data-hotness-suffix=true -function-sections -min-jump-table-entries=2 -disable-block-placement %s -o - 2>&1 | FileCheck %s --check-prefix=SECTION
-
-; Tests stat messages are expected.
-; TODO: Update test to verify section suffixes when target-lowering and assembler changes are implemented.
-; TODO: Also run static-data-splitter pass with -static-data-default-hotness=cold and check data section suffix.
+; When 'partition-static-data-sections' is enabled, static data splitter pass will
+; categorize jump tables and assembly printer will place hot jump tables in the
+; `.hot`-suffixed read only section, and cold ones in the `.rodata` sections.
+; Section names will optionally have `.<func>` if -function-sections is enabled.
+; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2  %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,HOT
+; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=false -min-jump-table-entries=2 %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNCLESS,JT
+
+; Tests that jump tables with unknown hotness are categorized as cold if `-static-data-default-hotness` specifies so.
+; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -min-jump-table-entries=2 -static-data-default-hotness=cold -function-sections=true %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULT
  
+ ; Tests stat messages are expected.
 ; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
-; STAT-DAG: 3 static-data-splitter - Number of hot jump tables seen
+; STAT-DAG: 2 static-data-splitter - Number of hot jump tables seen
 ; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
 
-; SECTION: .section .rodata.hot.foo,"a", at progbits
-; SECTION: .LJTI0_0:
-; SECTION: .LJTI0_2:
-; SECTION: .section .rodata.foo,"a", at progbits
-; SECTION: .LJTI0_1:
-; SECTION: .LJTI0_3:
+; Tests that the first and third jump table are placed in a hot-suffixed section,
+; and the second and fourth are placed in the original section.
+; FUNC: .section .rodata.hot.foo,"a", at progbits
+; FUNCLESS: .section .rodata.hot,"a", at progbits
+; JT: .LJTI0_0:
+; JT: .LJTI0_2:
+; FUNC: .section .rodata.foo,"a", at progbits
+; FUNCLESS: .section .rodata,"a", at progbits
+; JT: .LJTI0_1:
+; JT: .LJTI0_3:
+; HOT: .section .rodata.hot.func_without_entry_count,"a", at progbits
+; HOT: .LJTI1_0:
+
+; DEFAULT: .section .rodata.func_without_entry_count,"a", at progbits
+; DEFAULT: .LJTI1_0:
 
 ; @foo has four jump tables, jt0, jt1, jt2 and jt3 in the input basic block
 ; order; jt0 and jt2 are hot, and jt1 and jt3 are cold.
@@ -142,42 +156,6 @@ return:
   ret i32 %mod3
 }
 
-define i32 @func_with_hot_jt() !prof !15 {
-entry:
-  br label %for.body
-
-for.cond.cleanup:
-  ret i32 0
-
-for.body:
-  %lsr.iv = phi i32 [ 100000, %entry ], [ %lsr.iv.next, %loop.exit ]
-  %i.04 = phi i32 [ 0, %entry ], [ %inc, %loop.exit ]
-  %0 = urem i32 %i.04, 100
-  switch i32 %0, label %sw.default [
-    i32 1, label %loop.exit
-    i32 2, label %sw.bb1.i
-    i32 3, label %sw.bb3.i
-  ], !prof !19
-
-sw.bb1.i:
-  br label %loop.exit
-
-sw.bb3.i: 
-  call i32 (ptr, ...) @printf(ptr @case1)
-  br label %sw.default
-
-sw.default:
-  br label %loop.exit
-
-loop.exit:
-  %str.5.sink.i = phi ptr [ @str.10, %sw.default ], [ @str.9, %sw.bb1.i ], [ @case2, %for.body ]
-  call i32 @puts(ptr %str.5.sink.i)
-  %inc = add i32 %i.04, 1
-  %lsr.iv.next = add i32 %lsr.iv, -1
-  %exitcond.not = icmp eq i32 %lsr.iv.next, 0
-  br i1 %exitcond.not, label %for.body, label %for.cond.cleanup, !prof !17
-}
-
 define void @func_without_entry_count(i32 %num) {
 entry:
   switch i32 %num, label %sw.default [

>From e816defdec877b7c54e27fbe0e51660f67a8e074 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 14 Jan 2025 17:07:46 -0800
Subject: [PATCH 09/17] Emit hot and cold jump tables into .hot and .unlikely
 sections, respectively

---
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp    | 12 ++++++++---
 .../CodeGen/TargetLoweringObjectFileImpl.cpp  |  2 ++
 llvm/test/CodeGen/X86/jump-table-partition.ll | 21 ++++++++++---------
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index c4629561a4267e..7f80ef6e8a533f 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2885,8 +2885,14 @@ void AsmPrinter::emitJumpTableInfo() {
     return;
   }
 
+  // When static data partitioning is enabled, collect jump table entries that
+  // go into the same section together to reduce the amount of section switch
+  // statements.
+  //
   // Iterate all jump tables, put hot jump table indices towards the beginning
-  // of the vector, and cold jump table indices towards the end.
+  // of the vector, and cold jump table indices towards the end. Meanwhile
+  // retain the relative orders of original jump tables within a hot or unlikely
+  // section by reversing the cold jump table indices.
   int NextHotJumpTableIndex = 0, NextColdJumpTableIndex = JT.size() - 1;
   JumpTableIndices.resize(JT.size());
   for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
@@ -2902,8 +2908,8 @@ void AsmPrinter::emitJumpTableInfo() {
         TLOF.getSectionForJumpTable(F, TM, &JT[0]), JTInDiffSection, *MJTI);
   }
 
-  if (NextHotJumpTableIndex != (int)JT.size()) {
-    // Retain the relative orders of original jump tables.
+  if (NextHotJumpTableIndex < (int)JT.size()) {
+    // Reverse the order of cold jump tables indices.
     for (int L = NextHotJumpTableIndex, R = JT.size() - 1; L < R; ++L, --R)
       std::swap(JumpTableIndices[L], JumpTableIndices[R]);
   
diff --git a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
index 5e19df037290a4..2a00ecf80ac1e2 100644
--- a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
@@ -671,6 +671,8 @@ getELFSectionNameForGlobal(const GlobalObject *GO, SectionKind Kind,
     if (JTE) {
       if (JTE->Hotness == MachineFunctionDataHotness::Hot)
         raw_svector_ostream(Name) << ".hot";
+      else if (JTE->Hotness == MachineFunctionDataHotness::Cold)
+        raw_svector_ostream(Name) << ".unlikely";
     } else if (std::optional<StringRef> Prefix = F->getSectionPrefix()) {
       raw_svector_ostream(Name) << '.' << *Prefix;
       HasPrefix = true;
diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 8c542beef57bb4..6e359c14659125 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -8,13 +8,12 @@
 
 ; When 'partition-static-data-sections' is enabled, static data splitter pass will
 ; categorize jump tables and assembly printer will place hot jump tables in the
-; `.hot`-suffixed read only section, and cold ones in the `.rodata` sections.
+; `.rodata.hot`-prefixed section, and cold ones in the `.rodata.unlikely`-prefixed section.
 ; Section names will optionally have `.<func>` if -function-sections is enabled.
-; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2  %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,HOT
+; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2  %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTHOT
 ; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=false -min-jump-table-entries=2 %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNCLESS,JT
 
-; Tests that jump tables with unknown hotness are categorized as cold if `-static-data-default-hotness` specifies so.
-; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -min-jump-table-entries=2 -static-data-default-hotness=cold -function-sections=true %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULT
+; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -min-jump-table-entries=2 -static-data-default-hotness=cold -function-sections=true %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTCOLD
  
  ; Tests stat messages are expected.
 ; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
@@ -27,15 +26,17 @@
 ; FUNCLESS: .section .rodata.hot,"a", at progbits
 ; JT: .LJTI0_0:
 ; JT: .LJTI0_2:
-; FUNC: .section .rodata.foo,"a", at progbits
-; FUNCLESS: .section .rodata,"a", at progbits
+; FUNC: .section .rodata.unlikely.foo,"a", at progbits
+; FUNCLESS: .section .rodata.unlikely,"a", at progbits
 ; JT: .LJTI0_1:
 ; JT: .LJTI0_3:
-; HOT: .section .rodata.hot.func_without_entry_count,"a", at progbits
-; HOT: .LJTI1_0:
+; DEFAULTHOT: .section .rodata.hot.func_without_entry_count,"a", at progbits
+; DEFAULTHOT: .LJTI1_0:
+; FUNCLESS: .section .rodata.hot,"a", at progbits
+; FUNCLESS: .LJTI1_0:
 
-; DEFAULT: .section .rodata.func_without_entry_count,"a", at progbits
-; DEFAULT: .LJTI1_0:
+; DEFAULTCOLD: .section .rodata.unlikely.func_without_entry_count,"a", at progbits
+; DEFAULTCOLD: .LJTI1_0:
 
 ; @foo has four jump tables, jt0, jt1, jt2 and jt3 in the input basic block
 ; order; jt0 and jt2 are hot, and jt1 and jt3 are cold.

>From 7f3e4731a6bcda9e6fbdcaeb9a39b1edb47bef48 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 16 Jan 2025 09:09:40 -0800
Subject: [PATCH 10/17] Add -mtriple=x86_64-unknown-linux-gnu for test

---
 llvm/test/CodeGen/X86/jump-table-partition.ll | 33 ++++++++++---------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/llvm/test/CodeGen/X86/jump-table-partition.ll b/llvm/test/CodeGen/X86/jump-table-partition.ll
index 6e359c14659125..1a03bd9cc7bb6d 100644
--- a/llvm/test/CodeGen/X86/jump-table-partition.ll
+++ b/llvm/test/CodeGen/X86/jump-table-partition.ll
@@ -3,31 +3,34 @@
 
 ; Stop after 'finalize-isel' for simpler MIR, and lower the minimum number of
 ; jump table entries so 'switch' needs fewer cases to generate a jump table.
-; RUN: llc -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
-; RUN: llc --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -stop-after=finalize-isel -min-jump-table-entries=2 %s -o %t.mir
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu --run-pass=static-data-splitter -stats -x mir %t.mir -o - 2>&1 | FileCheck %s --check-prefix=STAT
+
+ ; Tests stat messages are expected.
+; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
+; STAT-DAG: 2 static-data-splitter - Number of hot jump tables seen
+; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
 
 ; When 'partition-static-data-sections' is enabled, static data splitter pass will
 ; categorize jump tables and assembly printer will place hot jump tables in the
 ; `.rodata.hot`-prefixed section, and cold ones in the `.rodata.unlikely`-prefixed section.
 ; Section names will optionally have `.<func>` if -function-sections is enabled.
-; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2  %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTHOT
-; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -function-sections=false -min-jump-table-entries=2 %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNCLESS,JT
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2 -unique-section-names=false  %s -o - 2>&1 | FileCheck %s --check-prefixes=LINEAR,JT
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -enable-split-machine-functions -partition-static-data-sections=true -function-sections=true -min-jump-table-entries=2  %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTHOT
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -enable-split-machine-functions -partition-static-data-sections=true -function-sections=false -min-jump-table-entries=2 %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNCLESS,JT --implicit-check-not=unique
 
-; RUN: llc -enable-split-machine-functions -partition-static-data-sections=true -min-jump-table-entries=2 -static-data-default-hotness=cold -function-sections=true %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTCOLD
- 
- ; Tests stat messages are expected.
-; STAT-DAG: 2 static-data-splitter - Number of cold jump tables seen
-; STAT-DAG: 2 static-data-splitter - Number of hot jump tables seen
-; STAT-DAG: 1 static-data-splitter - Number of jump tables with unknown hotness
+; Tests that `-static-data-default-hotness` can override hotness for data with
+; unknown hotness.
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu -enable-split-machine-functions -partition-static-data-sections=true -min-jump-table-entries=2 -static-data-default-hotness=cold -function-sections=true %s -o - 2>&1 | FileCheck %s --check-prefixes=FUNC,JT,DEFAULTCOLD
 
-; Tests that the first and third jump table are placed in a hot-suffixed section,
-; and the second and fourth are placed in the original section.
-; FUNC: .section .rodata.hot.foo,"a", at progbits
+; LINEAR:    .section .rodata.hot,"a", at progbits,unique,2
+; FUNC:     .section .rodata.hot.foo,"a", at progbits
 ; FUNCLESS: .section .rodata.hot,"a", at progbits
 ; JT: .LJTI0_0:
 ; JT: .LJTI0_2:
-; FUNC: .section .rodata.unlikely.foo,"a", at progbits
-; FUNCLESS: .section .rodata.unlikely,"a", at progbits
+; LINEAR:    	.section	.rodata.unlikely,"a", at progbits,unique,3
+; FUNC:       .section .rodata.unlikely.foo,"a", at progbits
+; FUNCLESS:   .section .rodata.unlikely,"a", at progbits
 ; JT: .LJTI0_1:
 ; JT: .LJTI0_3:
 ; DEFAULTHOT: .section .rodata.hot.func_without_entry_count,"a", at progbits

>From 1bef9b1b6ed5d94e159d8168dfe34e93d2064b73 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Thu, 16 Jan 2025 09:21:11 -0800
Subject: [PATCH 11/17] Run clang-format

---
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 7f80ef6e8a533f..4730c33d748832 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2916,7 +2916,8 @@ void AsmPrinter::emitJumpTableInfo() {
     emitJumpTables(
         ArrayRef<unsigned>(JumpTableIndices)
             .take_back(JT.size() - NextHotJumpTableIndex),
-        TLOF.getSectionForJumpTable(F, TM, &JT[JumpTableIndices[NextHotJumpTableIndex]]),
+        TLOF.getSectionForJumpTable(
+          F, TM, &JT[JumpTableIndices[NextHotJumpTableIndex]]),
         JTInDiffSection, *MJTI);
   }
 

>From 59a958a27602d89466cf13838ed28593fb2b1129 Mon Sep 17 00:00:00 2001
From: Mingming Liu <mingmingl at google.com>
Date: Fri, 17 Jan 2025 14:09:54 -0800
Subject: [PATCH 12/17] Apply suggestions from code review

Co-authored-by: Ellis Hoag <ellis.sparky.hoag at gmail.com>
---
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp        | 7 ++++---
 llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp | 5 +++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 4730c33d748832..45ce0928f22657 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2896,10 +2896,11 @@ void AsmPrinter::emitJumpTableInfo() {
   int NextHotJumpTableIndex = 0, NextColdJumpTableIndex = JT.size() - 1;
   JumpTableIndices.resize(JT.size());
   for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
-    if (JT[JTI].Hotness == MachineFunctionDataHotness::Cold)
+    if (JT[JTI].Hotness == MachineFunctionDataHotness::Cold) {
       JumpTableIndices[NextColdJumpTableIndex--] = JTI;
-    else
+    } else {
       JumpTableIndices[NextHotJumpTableIndex++] = JTI;
+    }
   }
 
   if (NextHotJumpTableIndex != 0) {
@@ -2945,7 +2946,7 @@ void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
 
   const auto &JT = MJTI.getJumpTables();
   for (unsigned Index = 0, e = JumpTableIndices.size(); Index != e; ++Index) {
-    const std::vector<MachineBasicBlock *> &JTBBs =
+    ArrayRef<MachineBasicBlock *> JTBBs =
         JT[JumpTableIndices[Index]].MBBs;
 
     // If this jump table was deleted, ignore it.
diff --git a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
index 2a00ecf80ac1e2..fb3ab5400f649e 100644
--- a/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringObjectFileImpl.cpp
@@ -669,10 +669,11 @@ getELFSectionNameForGlobal(const GlobalObject *GO, SectionKind Kind,
     // Jump table hotness takes precedence over its enclosing function's hotness
     // if both are available.
     if (JTE) {
-      if (JTE->Hotness == MachineFunctionDataHotness::Hot)
+      if (JTE->Hotness == MachineFunctionDataHotness::Hot) {
         raw_svector_ostream(Name) << ".hot";
-      else if (JTE->Hotness == MachineFunctionDataHotness::Cold)
+      } else if (JTE->Hotness == MachineFunctionDataHotness::Cold) {
         raw_svector_ostream(Name) << ".unlikely";
+      }
     } else if (std::optional<StringRef> Prefix = F->getSectionPrefix()) {
       raw_svector_ostream(Name) << '.' << *Prefix;
       HasPrefix = true;

>From 2d06092ced408c0fd90b7c82b3990f8284d8f918 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 17 Jan 2025 14:15:35 -0800
Subject: [PATCH 13/17] Fix Windows build bot failure by forward-declaring
 MachineJumpTableEntry as a struct

---
 llvm/include/llvm/Target/TargetLoweringObjectFile.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Target/TargetLoweringObjectFile.h b/llvm/include/llvm/Target/TargetLoweringObjectFile.h
index 577adc458fcbf1..2031f5f4498c19 100644
--- a/llvm/include/llvm/Target/TargetLoweringObjectFile.h
+++ b/llvm/include/llvm/Target/TargetLoweringObjectFile.h
@@ -21,13 +21,13 @@
 namespace llvm {
 
 struct Align;
+struct MachineJumpTableEntry;
 class Constant;
 class DataLayout;
 class Function;
 class GlobalObject;
 class GlobalValue;
 class MachineBasicBlock;
-class MachineJumpTableEntry;
 class MachineModuleInfo;
 class Mangler;
 class MCContext;

>From 9134ffa9dac29667374a0939bcf42aac1e0b9413 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Fri, 17 Jan 2025 15:54:26 -0800
Subject: [PATCH 14/17] Resolve review comments: replace emitJumpTables with
 emitJumpTableImpl

---
 llvm/include/llvm/CodeGen/AsmPrinter.h     |  8 +--
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 83 ++++++++++++----------
 2 files changed, 50 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index 4a7de22b844a65..dd63197df8ea50 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -453,10 +453,6 @@ class AsmPrinter : public MachineFunctionPass {
   /// function to the current output stream.
   virtual void emitJumpTableInfo();
 
-  virtual void emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
-                              MCSection *JumpTableSection, bool JTInDiffSection,
-                              const MachineJumpTableInfo &MJTI);
-
   /// Emit the specified global variable to the .s file.
   virtual void emitGlobalVariable(const GlobalVariable *GV);
 
@@ -896,6 +892,10 @@ class AsmPrinter : public MachineFunctionPass {
   // Internal Implementation Details
   //===------------------------------------------------------------------===//
 
+  template <typename Iterator>
+  void emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
+                         const llvm::iterator_range<Iterator> &JumpTableIndices,
+                         bool JTInDiffSection);
   void emitJumpTableEntry(const MachineJumpTableInfo &MJTI,
                           const MachineBasicBlock *MBB, unsigned uid) const;
 
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 45ce0928f22657..86907011ee029c 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2880,8 +2880,10 @@ void AsmPrinter::emitJumpTableInfo() {
   if (!TM.Options.EnableStaticDataPartitioning) {
     for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI)
       JumpTableIndices.push_back(JTI);
-    emitJumpTables(JumpTableIndices, TLOF.getSectionForJumpTable(F, TM),
-                   JTInDiffSection, *MJTI);
+    emitJumpTableImpl(
+        *MJTI,
+        llvm::make_range(JumpTableIndices.begin(), JumpTableIndices.end()),
+        JTInDiffSection);
     return;
   }
 
@@ -2891,47 +2893,56 @@ void AsmPrinter::emitJumpTableInfo() {
   //
   // Iterate all jump tables, put hot jump table indices towards the beginning
   // of the vector, and cold jump table indices towards the end. Meanwhile
-  // retain the relative orders of original jump tables within a hot or unlikely
-  // section by reversing the cold jump table indices.
-  int NextHotJumpTableIndex = 0, NextColdJumpTableIndex = JT.size() - 1;
+  // retain the relative orders of original jump tables.
+  int NumHotJumpTables = 0, NextColdJumpTableIndex = JT.size() - 1;
   JumpTableIndices.resize(JT.size());
   for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
     if (JT[JTI].Hotness == MachineFunctionDataHotness::Cold) {
       JumpTableIndices[NextColdJumpTableIndex--] = JTI;
     } else {
-      JumpTableIndices[NextHotJumpTableIndex++] = JTI;
+      JumpTableIndices[NumHotJumpTables++] = JTI;
     }
   }
 
-  if (NextHotJumpTableIndex != 0) {
-    emitJumpTables(
-        ArrayRef<unsigned>(JumpTableIndices).take_front(NextHotJumpTableIndex),
-        TLOF.getSectionForJumpTable(F, TM, &JT[0]), JTInDiffSection, *MJTI);
-  }
+  emitJumpTableImpl(
+      *MJTI,
+      llvm::make_range(JumpTableIndices.begin(),
+                       JumpTableIndices.begin() + NumHotJumpTables),
 
-  if (NextHotJumpTableIndex < (int)JT.size()) {
-    // Reverse the order of cold jump tables indices.
-    for (int L = NextHotJumpTableIndex, R = JT.size() - 1; L < R; ++L, --R)
-      std::swap(JumpTableIndices[L], JumpTableIndices[R]);
-  
-    emitJumpTables(
-        ArrayRef<unsigned>(JumpTableIndices)
-            .take_back(JT.size() - NextHotJumpTableIndex),
-        TLOF.getSectionForJumpTable(
-          F, TM, &JT[JumpTableIndices[NextHotJumpTableIndex]]),
-        JTInDiffSection, *MJTI);
-  }
+      JTInDiffSection);
+
+  const int NumColdJumpTables = JT.size() - NumHotJumpTables;
+  assert(NumColdJumpTables >= 0 && "Invalid number of cold jump tables.");
+
+  // Reverse iterating cold jump table indices to emit in the original order.
+  emitJumpTableImpl(
+      *MJTI,
+      llvm::make_range(JumpTableIndices.rbegin(),
+                       JumpTableIndices.rbegin() + NumColdJumpTables),
+      JTInDiffSection);
 
   return;
 }
 
-void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
-                                MCSection *JumpTableSection,
-                                bool JTInDiffSection,
-                                const MachineJumpTableInfo &MJTI) {
+template <typename Iterator>
+void AsmPrinter::emitJumpTableImpl(
+    const MachineJumpTableInfo &MJTI,
+    const llvm::iterator_range<Iterator> &JumpTableIndices,
+    bool JTInDiffSection) {
   if (JumpTableIndices.empty())
     return;
 
+  const TargetLoweringObjectFile &TLOF = getObjFileLowering();
+  const Function &F = MF->getFunction();
+  const std::vector<MachineJumpTableEntry> &JT = MJTI.getJumpTables();
+  MCSection *JumpTableSection = nullptr;
+  if (TM.Options.EnableStaticDataPartitioning) {
+    JumpTableSection =
+        TLOF.getSectionForJumpTable(F, TM, &JT[*JumpTableIndices.begin()]);
+  } else {
+    JumpTableSection = TLOF.getSectionForJumpTable(F, TM);
+  }
+
   const DataLayout &DL = MF->getDataLayout();
   if (JTInDiffSection) {
     OutStreamer->switchSection(JumpTableSection);
@@ -2944,10 +2955,8 @@ void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
   if (!JTInDiffSection)
     OutStreamer->emitDataRegion(MCDR_DataRegionJT32);
 
-  const auto &JT = MJTI.getJumpTables();
-  for (unsigned Index = 0, e = JumpTableIndices.size(); Index != e; ++Index) {
-    ArrayRef<MachineBasicBlock *> JTBBs =
-        JT[JumpTableIndices[Index]].MBBs;
+  for (const unsigned JumpTableIndex : JumpTableIndices) {
+    ArrayRef<MachineBasicBlock *> JTBBs = JT[JumpTableIndex].MBBs;
 
     // If this jump table was deleted, ignore it.
     if (JTBBs.empty())
@@ -2959,8 +2968,8 @@ void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
         MAI->doesSetDirectiveSuppressReloc()) {
       SmallPtrSet<const MachineBasicBlock *, 16> EmittedSets;
       const TargetLowering *TLI = MF->getSubtarget().getTargetLowering();
-      const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(
-          MF, JumpTableIndices[Index], OutContext);
+      const MCExpr *Base =
+          TLI->getPICJumpTableRelocBaseExpr(MF, JumpTableIndex, OutContext);
       for (const MachineBasicBlock *MBB : JTBBs) {
         if (!EmittedSets.insert(MBB).second)
           continue;
@@ -2969,7 +2978,7 @@ void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
         const MCExpr *LHS =
             MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
         OutStreamer->emitAssignment(
-            GetJTSetSymbol(JumpTableIndices[Index], MBB->getNumber()),
+            GetJTSetSymbol(JumpTableIndex, MBB->getNumber()),
             MCBinaryExpr::createSub(LHS, Base, OutContext));
       }
     }
@@ -2982,15 +2991,15 @@ void AsmPrinter::emitJumpTables(ArrayRef<unsigned> JumpTableIndices,
       // FIXME: This doesn't have to have any specific name, just any randomly
       // named and numbered local label started with 'l' would work.  Simplify
       // GetJTISymbol.
-      OutStreamer->emitLabel(GetJTISymbol(JumpTableIndices[Index], true));
+      OutStreamer->emitLabel(GetJTISymbol(JumpTableIndex, true));
 
-    MCSymbol *JTISymbol = GetJTISymbol(JumpTableIndices[Index]);
+    MCSymbol *JTISymbol = GetJTISymbol(JumpTableIndex);
     OutStreamer->emitLabel(JTISymbol);
 
     // Defer MCAssembler based constant folding due to a performance issue. The
     // label differences will be evaluated at write time.
     for (const MachineBasicBlock *MBB : JTBBs)
-      emitJumpTableEntry(MJTI, MBB, JumpTableIndices[Index]);
+      emitJumpTableEntry(MJTI, MBB, JumpTableIndex);
   }
 
   if (EmitJumpTableSizesSection)

>From 027ae5695a6a08604f398eefdb8a66c77981a151 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 27 Jan 2025 14:21:32 -0800
Subject: [PATCH 15/17] resolve review feedback

---
 llvm/include/llvm/CodeGen/AsmPrinter.h     |  3 +-
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 45 +++++-----------------
 2 files changed, 11 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index f5a852fdbd7ac3..3da63af5ba5716 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -893,9 +893,8 @@ class AsmPrinter : public MachineFunctionPass {
   // Internal Implementation Details
   //===------------------------------------------------------------------===//
 
-  template <typename Iterator>
   void emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
-                         const llvm::iterator_range<Iterator> &JumpTableIndices,
+                         ArrayRef<unsigned> JumpTableIndices,
                          bool JTInDiffSection);
   void emitJumpTableEntry(const MachineJumpTableInfo &MJTI,
                           const MachineBasicBlock *MBB, unsigned uid) const;
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index d29dd4a2d13261..0ca1484756142f 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2868,59 +2868,34 @@ void AsmPrinter::emitJumpTableInfo() {
           MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference64,
       F);
 
-  std::vector<unsigned> JumpTableIndices;
   if (!TM.Options.EnableStaticDataPartitioning) {
+    SmallVector<unsigned> JumpTableIndices;
     for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI)
       JumpTableIndices.push_back(JTI);
-    emitJumpTableImpl(
-        *MJTI,
-        llvm::make_range(JumpTableIndices.begin(), JumpTableIndices.end()),
-        JTInDiffSection);
+    emitJumpTableImpl(*MJTI, JumpTableIndices, JTInDiffSection);
     return;
   }
 
+  SmallVector<unsigned> HotJumpTableIndices, ColdJumpTableIndices;
   // When static data partitioning is enabled, collect jump table entries that
   // go into the same section together to reduce the amount of section switch
   // statements.
-  //
-  // Iterate all jump tables, put hot jump table indices towards the beginning
-  // of the vector, and cold jump table indices towards the end. Meanwhile
-  // retain the relative orders of original jump tables.
-  int NumHotJumpTables = 0, NextColdJumpTableIndex = JT.size() - 1;
-  JumpTableIndices.resize(JT.size());
   for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
     if (JT[JTI].Hotness == MachineFunctionDataHotness::Cold) {
-      JumpTableIndices[NextColdJumpTableIndex--] = JTI;
+      ColdJumpTableIndices.push_back(JTI);
     } else {
-      JumpTableIndices[NumHotJumpTables++] = JTI;
+      HotJumpTableIndices.push_back(JTI);
     }
   }
 
-  emitJumpTableImpl(
-      *MJTI,
-      llvm::make_range(JumpTableIndices.begin(),
-                       JumpTableIndices.begin() + NumHotJumpTables),
-
-      JTInDiffSection);
-
-  const int NumColdJumpTables = JT.size() - NumHotJumpTables;
-  assert(NumColdJumpTables >= 0 && "Invalid number of cold jump tables.");
-
-  // Reverse iterating cold jump table indices to emit in the original order.
-  emitJumpTableImpl(
-      *MJTI,
-      llvm::make_range(JumpTableIndices.rbegin(),
-                       JumpTableIndices.rbegin() + NumColdJumpTables),
-      JTInDiffSection);
-
+  emitJumpTableImpl(*MJTI, HotJumpTableIndices, JTInDiffSection);
+  emitJumpTableImpl(*MJTI, ColdJumpTableIndices, JTInDiffSection);
   return;
 }
 
-template <typename Iterator>
-void AsmPrinter::emitJumpTableImpl(
-    const MachineJumpTableInfo &MJTI,
-    const llvm::iterator_range<Iterator> &JumpTableIndices,
-    bool JTInDiffSection) {
+void AsmPrinter::emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
+                                   ArrayRef<unsigned> JumpTableIndices,
+                                   bool JTInDiffSection) {
   if (JumpTableIndices.empty())
     return;
 

>From fd0034aa40b216287fef8689b98aff542d8a2969 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 27 Jan 2025 14:39:55 -0800
Subject: [PATCH 16/17] [NFCI] Refactor AsmPrinter around jump table emission

---
 llvm/include/llvm/CodeGen/AsmPrinter.h     |  3 ++
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 56 +++++++++++++++-------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index 5291369b3b9f1d..3da63af5ba5716 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -893,6 +893,9 @@ class AsmPrinter : public MachineFunctionPass {
   // Internal Implementation Details
   //===------------------------------------------------------------------===//
 
+  void emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
+                         ArrayRef<unsigned> JumpTableIndices,
+                         bool JTInDiffSection);
   void emitJumpTableEntry(const MachineJumpTableInfo &MJTI,
                           const MachineBasicBlock *MBB, unsigned uid) const;
 
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index b2a4721f37b268..c21915673f643d 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2868,42 +2868,62 @@ void AsmPrinter::emitJumpTableInfo() {
       MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 ||
           MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference64,
       F);
+
+  SmallVector<unsigned> JumpTableIndices;
+  for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI) {
+    JumpTableIndices.push_back(JTI);
+  }
+  emitJumpTableImpl(*MJTI, JumpTableIndices, JTInDiffSection);
+}
+
+void AsmPrinter::emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
+                                   ArrayRef<unsigned> JumpTableIndices,
+                                   bool JTInDiffSection) {
+  if (JumpTableIndices.empty())
+    return;
+
+  const TargetLoweringObjectFile &TLOF = getObjFileLowering();
+  const Function &F = MF->getFunction();
+  const std::vector<MachineJumpTableEntry> &JT = MJTI.getJumpTables();
+  MCSection *JumpTableSection = TLOF.getSectionForJumpTable(F, TM);
+
+  const DataLayout &DL = MF->getDataLayout();
   if (JTInDiffSection) {
-    // Drop it in the readonly section.
-    MCSection *ReadOnlySection = TLOF.getSectionForJumpTable(F, TM);
-    OutStreamer->switchSection(ReadOnlySection);
+    OutStreamer->switchSection(JumpTableSection);
   }
 
-  emitAlignment(Align(MJTI->getEntryAlignment(DL)));
+  emitAlignment(Align(MJTI.getEntryAlignment(MF->getDataLayout())));
 
   // Jump tables in code sections are marked with a data_region directive
   // where that's supported.
   if (!JTInDiffSection)
     OutStreamer->emitDataRegion(MCDR_DataRegionJT32);
 
-  for (unsigned JTI = 0, e = JT.size(); JTI != e; ++JTI) {
-    const std::vector<MachineBasicBlock*> &JTBBs = JT[JTI].MBBs;
+  for (const unsigned JumpTableIndex : JumpTableIndices) {
+    ArrayRef<MachineBasicBlock *> JTBBs = JT[JumpTableIndex].MBBs;
 
     // If this jump table was deleted, ignore it.
-    if (JTBBs.empty()) continue;
+    if (JTBBs.empty())
+      continue;
 
     // For the EK_LabelDifference32 entry, if using .set avoids a relocation,
     /// emit a .set directive for each unique entry.
-    if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
+    if (MJTI.getEntryKind() == MachineJumpTableInfo::EK_LabelDifference32 &&
         MAI->doesSetDirectiveSuppressReloc()) {
-      SmallPtrSet<const MachineBasicBlock*, 16> EmittedSets;
+      SmallPtrSet<const MachineBasicBlock *, 16> EmittedSets;
       const TargetLowering *TLI = MF->getSubtarget().getTargetLowering();
-      const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF,JTI,OutContext);
+      const MCExpr *Base =
+          TLI->getPICJumpTableRelocBaseExpr(MF, JumpTableIndex, OutContext);
       for (const MachineBasicBlock *MBB : JTBBs) {
         if (!EmittedSets.insert(MBB).second)
           continue;
 
         // .set LJTSet, LBB32-base
         const MCExpr *LHS =
-          MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
-        OutStreamer->emitAssignment(GetJTSetSymbol(JTI, MBB->getNumber()),
-                                    MCBinaryExpr::createSub(LHS, Base,
-                                                            OutContext));
+            MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
+        OutStreamer->emitAssignment(
+            GetJTSetSymbol(JumpTableIndex, MBB->getNumber()),
+            MCBinaryExpr::createSub(LHS, Base, OutContext));
       }
     }
 
@@ -2915,19 +2935,19 @@ void AsmPrinter::emitJumpTableInfo() {
       // FIXME: This doesn't have to have any specific name, just any randomly
       // named and numbered local label started with 'l' would work.  Simplify
       // GetJTISymbol.
-      OutStreamer->emitLabel(GetJTISymbol(JTI, true));
+      OutStreamer->emitLabel(GetJTISymbol(JumpTableIndex, true));
 
-    MCSymbol* JTISymbol = GetJTISymbol(JTI);
+    MCSymbol *JTISymbol = GetJTISymbol(JumpTableIndex);
     OutStreamer->emitLabel(JTISymbol);
 
     // Defer MCAssembler based constant folding due to a performance issue. The
     // label differences will be evaluated at write time.
     for (const MachineBasicBlock *MBB : JTBBs)
-      emitJumpTableEntry(*MJTI, MBB, JTI);
+      emitJumpTableEntry(MJTI, MBB, JumpTableIndex);
   }
 
   if (EmitJumpTableSizesSection)
-    emitJumpTableSizesSection(*MJTI, F);
+    emitJumpTableSizesSection(MJTI, MF->getFunction());
 
   if (!JTInDiffSection)
     OutStreamer->emitDataRegion(MCDR_DataRegionEnd);

>From b919de9be0348ed4e2c79c436e0e4937c4b15e77 Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Tue, 28 Jan 2025 12:58:16 -0800
Subject: [PATCH 17/17] resolve comments

---
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 47467de53f80af..44b10c3ef99726 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2869,10 +2869,8 @@ void AsmPrinter::emitJumpTableInfo() {
       F);
 
   if (!TM.Options.EnableStaticDataPartitioning) {
-    SmallVector<unsigned> JumpTableIndices;
-    for (unsigned JTI = 0, JTSize = JT.size(); JTI < JTSize; ++JTI)
-      JumpTableIndices.push_back(JTI);
-    emitJumpTableImpl(*MJTI, JumpTableIndices, JTInDiffSection);
+    emitJumpTableImpl(*MJTI, llvm::to_vector(llvm::seq<unsigned>(JT.size())),
+                      JTInDiffSection);
     return;
   }
 
@@ -2904,7 +2902,7 @@ void AsmPrinter::emitJumpTableImpl(const MachineJumpTableInfo &MJTI,
   MCSection *JumpTableSection = nullptr;
   if (TM.Options.EnableStaticDataPartitioning) {
     JumpTableSection =
-        TLOF.getSectionForJumpTable(F, TM, &JT[*JumpTableIndices.begin()]);
+        TLOF.getSectionForJumpTable(F, TM, &JT[JumpTableIndices.front()]);
   } else {
     JumpTableSection = TLOF.getSectionForJumpTable(F, TM);
   }



More information about the llvm-commits mailing list