[llvm] [RFC][BPF] Support Jump Table (PR #133856)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 17 11:26:29 PDT 2025


https://github.com/yonghong-song updated https://github.com/llvm/llvm-project/pull/133856

>From db8616817248c054eabd96e1ea0d7a28661023ab Mon Sep 17 00:00:00 2001
From: Yonghong Song <yonghong.song at linux.dev>
Date: Mon, 31 Mar 2025 21:25:26 -0700
Subject: [PATCH] [RFC][BPF] Support Jump Table

NOTE: We probably need cpu v5 or other flags to enable this feature.
We can add it later when necessary.

This patch adds jump table support. A new insn 'gotox <reg>' is
added to allow an indirect goto through a register, where the register
holds a code address in the current section. The bpf selftest
progs/user_ringbuf_success.c is used below as a concrete example.
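
For reference, a jump table here comes from a sufficiently dense switch
statement. The snippet below is only a hedged sketch of such a switch
(placeholder names and bodies, not the actual selftest source), assuming
the default -bpf-min-jump-table-entries value of 4:

  /* Hypothetical sketch, not the actual selftest source: a dense switch
   * with >= 4 cases is what gets lowered to a gotox-based jump table. */
  long kern_mutated;

  void handle_msg(int op, long val)
  {
      switch (op) {
      case 0: kern_mutated += val; break;
      case 1: kern_mutated *= val; break;
      case 2: kern_mutated += (int)val; break;
      case 3: kern_mutated *= (int)val; break;
      }
  }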

Compilation command line to generate .s file:
=============================================
clang  -g -Wall -Werror -D__TARGET_ARCH_x86 -mlittle-endian \
    -I/home/yhs/work/bpf-next/tools/testing/selftests/bpf/tools/include \
    -I/home/yhs/work/bpf-next/tools/testing/selftests/bpf \
    -I/home/yhs/work/bpf-next/tools/include/uapi \
    -I/home/yhs/work/bpf-next/tools/testing/selftests/usr/include -std=gnu11 \
    -fno-strict-aliasing -Wno-compare-distinct-pointer-types \
    -idirafter /home/yhs/work/llvm-project/llvm/build.21/Release/lib/clang/21/include \
    -idirafter /usr/local/include -idirafter /usr/include \
    -DENABLE_ATOMICS_TESTS   -O2 -S progs/user_ringbuf_success.c \
    -o /home/yhs/work/bpf-next/tools/testing/selftests/bpf/user_ringbuf_success.bpf.o.s \
    --target=bpf -mcpu=v3

The related assembly:
  read_protocol_msg:
        ...
        r3 <<= 3
        r1 = .LJTI1_0 ll
        r1 += r3
        r1 = *(u64 *)(r1 + 0)
        gotox r1
  LBB1_4:
        r1 = *(u64 *)(r0 + 8)
        goto LBB1_5
  LBB1_7:
        r1 = *(u64 *)(r0 + 8)
        goto LBB1_8
  LBB1_9:
        w1 = *(u32 *)(r0 + 8)
        r1 <<= 32
        r1 s>>= 32
        r2 = kern_mutated ll
        r3 = *(u64 *)(r2 + 0)
        r3 *= r1
        *(u64 *)(r2 + 0) = r3
        goto LBB1_11
  LBB1_6:
        w1 = *(u32 *)(r0 + 8)
        r1 <<= 32
        r1 s>>= 32
  LBB1_5:
  ...
        .section        .rodata,"a",@progbits
        .p2align        3, 0x0
  .LJTI1_0:
        .quad   LBB1_4
        .quad   LBB1_6
        .quad   LBB1_7
        .quad   LBB1_9
  ...
  publish_next_kern_msg:
        ...
        r6 <<= 3
        r1 = .LJTI6_0 ll
        r1 += r6
        r1 = *(u64 *)(r1 + 0)
        gotox r1
  LBB6_3:
        ...
  LBB6_5:
        ...
  LBB6_6:
        ...
  LBB6_4:
        ...
        .section        .rodata,"a",@progbits
        .p2align        3, 0x0
.LJTI6_0:
        .quad   LBB6_3
        .quad   LBB6_4
        .quad   LBB6_5
        .quad   LBB6_6

You can see in the above that .LJTI1_0 and .LJTI6_0 are the jump tables.
Each table address is loaded into a register, indexed, and the loaded
entry becomes the target of the gotox insn.

Now let us look at the sections in the .o file
=======================================
For example,
  [ 6] .rodata           PROGBITS        0000000000000000 000740 0000d6 00   A  0   0  8
  [ 7] .rel.rodata       REL             0000000000000000 003860 000080 10   I 39   6  8
  [ 8] .llvm_jump_table_sizes LLVM_JT_SIZES 0000000000000000 000816 000010 00      0   0  1
  [ 9] .rel.llvm_jump_table_sizes REL    0000000000000000 0038e0 000010 10   I 39   8  8
  ...
  [14] .llvm_jump_table_sizes LLVM_JT_SIZES 0000000000000000 000958 000010 00      0   0  1
  [15] .rel.llvm_jump_table_sizes REL    0000000000000000 003970 000010 10   I 39  14  8

Dump sections 8 and 14 with llvm-readelf:
  $ llvm-readelf -x 8 user_ringbuf_success.bpf.o
  Hex dump of section '.llvm_jump_table_sizes':
  0x00000000 00000000 00000000 04000000 00000000 ................
  $ llvm-readelf -x 14 user_ringbuf_success.bpf.o
  Hex dump of section '.llvm_jump_table_sizes':
  0x00000000 20000000 00000000 04000000 00000000  ...............
You can see there are two jump tables:
  jump table 1: offset 0, size 4 (4 labels)
  jump table 2: offset 0x20, size 4 (4 labels)
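
As the dumps above show, each record in .llvm_jump_table_sizes is simply a
pair of u64 values: the (relocated) start address of the table and its
number of entries. Below is a minimal parsing sketch for such records,
assuming the section bytes are already in memory and host/target endianness
match (the struct and function names are made up for illustration):

  #include <stdint.h>
  #include <stdio.h>

  /* Hypothetical record layout for one .llvm_jump_table_sizes entry:
   * two u64s, the jump table start (relocated against .rodata here)
   * and the number of 8-byte label entries. */
  struct jt_size_rec {
      uint64_t addr;     /* offset of the jump table in .rodata */
      uint64_t nentries; /* number of labels in the jump table */
  };

  static void dump_jt_sizes(const void *data, size_t len)
  {
      const struct jt_size_rec *rec = data;

      for (size_t i = 0; i < len / sizeof(*rec); i++)
          printf("jump table at .rodata+0x%llx, %llu labels\n",
                 (unsigned long long)rec[i].addr,
                 (unsigned long long)rec[i].nentries);
  }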

Checking sections 9 and 15, we find the corresponding relocation entries:
  Relocation section '.rel.llvm_jump_table_sizes' at offset 0x38e0 contains 1 entries:
      Offset             Info             Type               Symbol's Value  Symbol's Name
  0000000000000000  0000000a00000002 R_BPF_64_ABS64         0000000000000000 .rodata
  Relocation section '.rel.llvm_jump_table_sizes' at offset 0x3970 contains 1 entries:
      Offset             Info             Type               Symbol's Value  Symbol's Name
  0000000000000000  0000000a00000002 R_BPF_64_ABS64         0000000000000000 .rodata
which confirms that both relocations are against '.rodata'.

Dump .rodata section:
  0x00000000 a8000000 00000000 10010000 00000000 ................
  0x00000010 b8000000 00000000 c8000000 00000000 ................
  0x00000020 28040000 00000000 00050000 00000000 (...............
  0x00000030 70040000 00000000 b8040000 00000000 p...............
  0x00000040 44726169 6e207265 7475726e 65643a20 Drain returned:

So we can get two jump tables:
  .rodata offset 0, # of labels 4:
  0x00000000 a8000000 00000000 10010000 00000000 ................
  0x00000010 b8000000 00000000 c8000000 00000000 ................
  .rodata offset 0x20, # of labels 4:
  0x00000020 28040000 00000000 00050000 00000000 (...............
  0x00000030 70040000 00000000 b8040000 00000000 p...............

This way, a loader only needs to scan the related code section: as long
as an insn's relocation matches one of the jump tables (a '.rodata'
relocation whose offset matches a jump table start), libbpf does not need
to care about the gotox insn at all (see the rough sketch below).
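
A rough loader-side sketch of that matching step (this is not actual
libbpf code; the relocation iteration is elided and all names are made up):

  #include <stdbool.h>
  #include <stddef.h>
  #include <stdint.h>

  /* Hypothetical descriptor built from .llvm_jump_table_sizes. */
  struct jt_desc {
      uint64_t rodata_off; /* start of the table within .rodata */
      uint64_t nentries;   /* number of 8-byte label entries */
  };

  /* Treat an LD_imm64 whose relocation resolves into .rodata as a jump
   * table load iff the relocated offset matches a recorded table start. */
  static bool is_jump_table_load(uint64_t reloc_rodata_off,
                                 const struct jt_desc *jts, size_t njts)
  {
      for (size_t i = 0; i < njts; i++)
          if (jts[i].rodata_off == reloc_rodata_off)
              return true;
      return false;
  }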

An option -bpf-min-jump-table-entries is implemented to control the
minimum number of entries required to use a jump table on BPF. The
default value is 4, but it can be changed with the following clang option
  clang ... -mllvm -bpf-min-jump-table-entries=6
in which case the number of switch cases needs to be >= 6 in order to
use a jump table.
---
 llvm/include/llvm/CodeGen/AsmPrinter.h        |  2 +
 llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp    |  2 +-
 .../lib/Target/BPF/AsmParser/BPFAsmParser.cpp |  1 +
 llvm/lib/Target/BPF/BPFAsmPrinter.cpp         |  1 +
 llvm/lib/Target/BPF/BPFISelLowering.cpp       | 36 +++++++++++++++-
 llvm/lib/Target/BPF/BPFISelLowering.h         |  2 +
 llvm/lib/Target/BPF/BPFInstrInfo.cpp          | 41 +++++++++++++++++++
 llvm/lib/Target/BPF/BPFInstrInfo.h            |  3 ++
 llvm/lib/Target/BPF/BPFInstrInfo.td           | 27 ++++++++++++
 llvm/lib/Target/BPF/BPFMCInstLower.cpp        |  3 ++
 10 files changed, 115 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index 6ad54fcd6d0e5..8cf00cc370821 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -26,6 +26,7 @@
 #include "llvm/CodeGen/StackMaps.h"
 #include "llvm/DebugInfo/CodeView/CodeView.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
 #include <cstdint>
@@ -34,6 +35,7 @@
 #include <vector>
 
 namespace llvm {
+extern cl::opt<bool> EmitJumpTableSizesSection;
 
 class AddrLabelMap;
 class AsmPrinterHandler;
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index a2c3b50b24670..6a93569e52b28 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -168,7 +168,7 @@ static cl::opt<bool> BBAddrMapSkipEmitBBEntries(
              "unnecessary for some PGOAnalysisMap features."),
     cl::Hidden, cl::init(false));
 
-static cl::opt<bool> EmitJumpTableSizesSection(
+cl::opt<bool> llvm::EmitJumpTableSizesSection(
     "emit-jump-table-sizes-section",
     cl::desc("Emit a section containing jump table addresses and sizes"),
     cl::Hidden, cl::init(false));
diff --git a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
index 7d1819134d162..3a8f559be942c 100644
--- a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
+++ b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
@@ -232,6 +232,7 @@ struct BPFOperand : public MCParsedAsmOperand {
         .Case("callx", true)
         .Case("goto", true)
         .Case("gotol", true)
+        .Case("gotox", true)
         .Case("may_goto", true)
         .Case("*", true)
         .Case("exit", true)
diff --git a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp
index 5dd71cc91427a..e2856bab354c8 100644
--- a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp
+++ b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp
@@ -57,6 +57,7 @@ class BPFAsmPrinter : public AsmPrinter {
 } // namespace
 
 bool BPFAsmPrinter::doInitialization(Module &M) {
+  EmitJumpTableSizesSection = true;
   AsmPrinter::doInitialization(M);
 
   // Only emit BTF when debuginfo available.
diff --git a/llvm/lib/Target/BPF/BPFISelLowering.cpp b/llvm/lib/Target/BPF/BPFISelLowering.cpp
index f4f414d192df0..154db34be786a 100644
--- a/llvm/lib/Target/BPF/BPFISelLowering.cpp
+++ b/llvm/lib/Target/BPF/BPFISelLowering.cpp
@@ -38,6 +38,10 @@ static cl::opt<bool> BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order",
   cl::Hidden, cl::init(false),
   cl::desc("Expand memcpy into load/store pairs in order"));
 
+static cl::opt<unsigned> BPFMinimumJumpTableEntries(
+    "bpf-min-jump-table-entries", cl::init(4), cl::Hidden,
+    cl::desc("Set minimum number of entries to use a jump table on BPF"));
+
 static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg,
                  SDValue Val = {}) {
   std::string Str;
@@ -67,12 +71,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::BR_CC, MVT::i64, Custom);
   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
-  setOperationAction(ISD::BRIND, MVT::Other, Expand);
   setOperationAction(ISD::BRCOND, MVT::Other, Expand);
 
   setOperationAction(ISD::TRAP, MVT::Other, Custom);
 
-  setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom);
+  setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable,
+                      ISD::BlockAddress},
+                     MVT::i64, Custom);
 
   setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
   setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
@@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
 
   setBooleanContents(ZeroOrOneBooleanContent);
   setMaxAtomicSizeInBitsSupported(64);
+  setMinimumJumpTableEntries(BPFMinimumJumpTableEntries);
 
   // Function alignments
   setMinFunctionAlignment(Align(8));
@@ -316,10 +322,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode()));
   case ISD::BR_CC:
     return LowerBR_CC(Op, DAG);
+  case ISD::JumpTable:
+    return LowerJumpTable(Op, DAG);
   case ISD::GlobalAddress:
     return LowerGlobalAddress(Op, DAG);
   case ISD::ConstantPool:
     return LowerConstantPool(Op, DAG);
+  case ISD::BlockAddress:
+    return LowerBlockAddress(Op, DAG);
   case ISD::SELECT_CC:
     return LowerSELECT_CC(Op, DAG);
   case ISD::SDIV:
@@ -780,6 +790,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const {
   return LowerCall(CLI, InVals);
 }
 
+SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
+  JumpTableSDNode *N = cast<JumpTableSDNode>(Op);
+  return getAddr(N, DAG);
+}
+
 const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
   switch ((BPFISD::NodeType)Opcode) {
   case BPFISD::FIRST_NUMBER:
@@ -811,6 +826,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty,
                                    N->getOffset(), Flags);
 }
 
+static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty,
+                             SelectionDAG &DAG, unsigned Flags) {
+  return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(),
+                                   Flags);
+}
+
+static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,
+                             SelectionDAG &DAG, unsigned Flags) {
+  return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags);
+}
+
 template <class NodeTy>
 SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
                                    unsigned Flags) const {
@@ -837,6 +863,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op,
   return getAddr(N, DAG);
 }
 
+SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  BlockAddressSDNode *N = cast<BlockAddressSDNode>(Op);
+  return getAddr(N, DAG);
+}
+
 unsigned
 BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB,
                                  unsigned Reg, bool isSigned) const {
diff --git a/llvm/lib/Target/BPF/BPFISelLowering.h b/llvm/lib/Target/BPF/BPFISelLowering.h
index 23cbce7094e6b..acb8f27c647d7 100644
--- a/llvm/lib/Target/BPF/BPFISelLowering.h
+++ b/llvm/lib/Target/BPF/BPFISelLowering.h
@@ -81,6 +81,8 @@ class BPFTargetLowering : public TargetLowering {
   SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
 
   template <class NodeTy>
   SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;
diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.cpp b/llvm/lib/Target/BPF/BPFInstrInfo.cpp
index 70bc163615f61..78626c39e80f7 100644
--- a/llvm/lib/Target/BPF/BPFInstrInfo.cpp
+++ b/llvm/lib/Target/BPF/BPFInstrInfo.cpp
@@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
     if (!isUnpredicatedTerminator(*I))
       break;
 
+    // If a JX insn, we're done.
+    if (I->getOpcode() == BPF::JX)
+      break;
+
     // A terminator that isn't a branch can't easily be handled
     // by this analysis.
     if (!I->isBranch())
@@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB,
 
   return Count;
 }
+
+int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const {
+  // The pattern looks like:
+  // %0 = LD_imm64 %jump-table.0   ; load jump-table address
+  // %1 = ADD_rr %0, $another_reg  ; address + offset
+  // %2 = LDD %1, 0                ; load the actual label
+  // JX %2
+  const MachineFunction &MF = *MI.getParent()->getParent();
+  const MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  Register Reg = MI.getOperand(0).getReg();
+  if (!Reg.isVirtual())
+    return -1;
+  MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg);
+  if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD)
+    return -1;
+
+  Reg = Ldd->getOperand(1).getReg();
+  if (!Reg.isVirtual())
+    return -1;
+  MachineInstr *Add = MRI.getUniqueVRegDef(Reg);
+  if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr)
+    return -1;
+
+  Reg = Add->getOperand(1).getReg();
+  if (!Reg.isVirtual())
+    return -1;
+  MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg);
+  if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64)
+    return -1;
+
+  const MachineOperand &MO = LDimm64->getOperand(1);
+  if (!MO.isJTI())
+    return -1;
+
+  return MO.getIndex();
+}
diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.h b/llvm/lib/Target/BPF/BPFInstrInfo.h
index d8bbad44e314e..d88e37975980a 100644
--- a/llvm/lib/Target/BPF/BPFInstrInfo.h
+++ b/llvm/lib/Target/BPF/BPFInstrInfo.h
@@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo {
                         MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
                         const DebugLoc &DL,
                         int *BytesAdded = nullptr) const override;
+
+  int getJumpTableIndex(const MachineInstr &MI) const override;
+
 private:
   void expandMEMCPY(MachineBasicBlock::iterator) const;
 
diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.td b/llvm/lib/Target/BPF/BPFInstrInfo.td
index b21f1a0eee3b0..c715bdb01866a 100644
--- a/llvm/lib/Target/BPF/BPFInstrInfo.td
+++ b/llvm/lib/Target/BPF/BPFInstrInfo.td
@@ -183,6 +183,15 @@ class TYPE_LD_ST<bits<3> mode, bits<2> size,
   let Inst{60-59} = size;
 }
 
+// For indirect jump
+class TYPE_IND_JMP<bits<4> op, bits<1> srctype,
+                   dag outs, dag ins, string asmstr, list<dag> pattern>
+  : InstBPF<outs, ins, asmstr, pattern> {
+
+  let Inst{63-60} = op;
+  let Inst{59} = srctype;
+}
+
 // jump instructions
 class JMP_RR<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
     : TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
@@ -216,6 +225,18 @@ class JMP_RI<BPFJumpOp Opc, string OpcodeStr, PatLeaf Cond>
   let BPFClass = BPF_JMP;
 }
 
+class JMP_IND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
+    : TYPE_ALU_JMP<Opc.Value, BPF_X.Value,
+                   (outs),
+                   (ins GPR:$dst),
+                   !strconcat(OpcodeStr, " $dst"),
+                   Pattern> {
+  bits<4> dst;
+
+  let Inst{51-48} = dst;
+  let BPFClass = BPF_JMP;
+}
+
 class JMP_JCOND<BPFJumpOp Opc, string OpcodeStr, list<dag> Pattern>
     : TYPE_ALU_JMP<Opc.Value, BPF_K.Value,
                    (outs),
@@ -281,6 +302,10 @@ defm JSLT : J<BPF_JSLT, "s<", BPF_CC_LT, BPF_CC_LT_32>;
 defm JSLE : J<BPF_JSLE, "s<=", BPF_CC_LE, BPF_CC_LE_32>;
 defm JSET : J<BPF_JSET, "&", NoCond, NoCond>;
 def JCOND : JMP_JCOND<BPF_JCOND, "may_goto", []>;
+
+let isIndirectBranch = 1 in {
+  def JX : JMP_IND<BPF_JA, "gotox", [(brind i64:$dst)]>;
+}
 }
 
 // ALU instructions
@@ -851,6 +876,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in {
 // load 64-bit global addr into register
 def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>;
 def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>;
+def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>;
+def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>;
 
 // 0xffffFFFF doesn't fit into simm32, optimize common case
 def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)),
diff --git a/llvm/lib/Target/BPF/BPFMCInstLower.cpp b/llvm/lib/Target/BPF/BPFMCInstLower.cpp
index 040a1fb750702..164d172c241c8 100644
--- a/llvm/lib/Target/BPF/BPFMCInstLower.cpp
+++ b/llvm/lib/Target/BPF/BPFMCInstLower.cpp
@@ -77,6 +77,9 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const {
     case MachineOperand::MO_ConstantPoolIndex:
       MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex()));
       break;
+    case MachineOperand::MO_JumpTableIndex:
+      MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex()));
+      break;
     }
 
     OutMI.addOperand(MCOp);


