If a jump table has trailing entries that result from __builtin_unreachable() targets, BOLT can confuse them with function pointers. In such a case, we should exclude these targets from the table, as otherwise we risk incorrectly updating the function pointers. It is safe to exclude them because branching to such targets is undefined behavior.

>From 9db2dc32cc4daa03fda9a919e847c5c4a9711905 Mon Sep 17 00:00:00 2001
From: Maksim Panchenko <maks at meta.com>
Date: Fri, 29 Mar 2024 21:05:39 -0700
Subject: [PATCH] [BOLT] Fix handling of trailing entries in jump tables

 bolt/lib/Core/BinaryContext.cpp      |  26 ++++-
 bolt/test/runtime/X86/jt-confusion.s | 164 +++++++++++++++++++++++++++
 2 files changed, 185 insertions(+), 5 deletions(-)
 create mode 100644 bolt/test/runtime/X86/jt-confusion.s

diff --git a/bolt/lib/Core/BinaryContext.cpp b/bolt/lib/Core/BinaryContext.cpp
index 47eae964e816c5..7c2d8c52287be1 100644
--- a/bolt/lib/Core/BinaryContext.cpp
+++ b/bolt/lib/Core/BinaryContext.cpp
@@ -555,6 +555,9 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address,
                                      const uint64_t NextJTAddress,
                                      JumpTable::AddressesType *EntriesAsAddress,
                                      bool *HasEntryInFragment) const {
+  // Target address of __builtin_unreachable.
+  const uint64_t UnreachableAddress = BF.getAddress() + BF.getSize();
   // Is one of the targets __builtin_unreachable?
   bool HasUnreachable = false;
@@ -564,9 +567,15 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address,
   // Number of targets other than __builtin_unreachable.
   uint64_t NumRealEntries = 0;
-  auto addEntryAddress = [&](uint64_t EntryAddress) {
-    if (EntriesAsAddress)
-      EntriesAsAddress->emplace_back(EntryAddress);
+  // Size of the jump table without trailing __builtin_unreachable entries.
+  size_t TrimmedSize = 0;
+  auto addEntryAddress = [&](uint64_t EntryAddress, bool Unreachable = false) {
+    if (!EntriesAsAddress)
+      return;
+    EntriesAsAddress->emplace_back(EntryAddress);
+    if (!Unreachable)
+      TrimmedSize = EntriesAsAddress->size();
   ErrorOr<const BinarySection &> Section = getSectionForAddress(Address);
@@ -618,8 +627,8 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address,
             : *getPointerAtAddress(EntryAddress);
     // __builtin_unreachable() case.
-    if (Value == BF.getAddress() + BF.getSize()) {
-      addEntryAddress(Value);
+    if (Value == UnreachableAddress) {
+      addEntryAddress(Value, /*Unreachable*/ true);
       HasUnreachable = true;
       LLVM_DEBUG(dbgs() << formatv("OK: {0:x} __builtin_unreachable\n", Value));
@@ -673,6 +682,13 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address,
+  // Trim direct/normal jump table to exclude trailing unreachable entries that
+  // can collide with a function address.
+  if (Type == JumpTable::JTT_NORMAL && EntriesAsAddress &&
+      TrimmedSize != EntriesAsAddress->size() &&
+      getBinaryFunctionAtAddress(UnreachableAddress))
+    EntriesAsAddress->resize(TrimmedSize);
   // It's a jump table if the number of real entries is more than 1, or there's
   // one real entry and one or more special targets. If there are only multiple
   // special targets, then it's not a jump table.
diff --git a/bolt/test/runtime/X86/jt-confusion.s b/bolt/test/runtime/X86/jt-confusion.s
new file mode 100644
index 00000000000000..f15c83b35b6a44
--- /dev/null
+++ b/bolt/test/runtime/X86/jt-confusion.s
@@ -0,0 +1,164 @@
+# REQUIRES: system-linux
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: llvm-strip --strip-unneeded %t.o
+# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q
+# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0
+# RUN: %t.exe.bolt
+## Check that BOLT's jump table detection differentiates between
+## __builtin_unreachable() targets and function pointers.
+## The test case was built from the following two source files and
+## modified for a standalone build. main became _start, etc.
+## $ $(CC) a.c -O1 -S -o a.s
+## $ $(CC) b.c -O0 -S -o b.s
+## a.c:
+## typedef int (*fptr)(int);
+## void check_fptr(fptr, int);
+## int foo(int a) {
+##   check_fptr(foo, 0);
+##   switch (a) {
+##   default:
+##     __builtin_unreachable();
+##   case 0:
+##     return 3;
+##   case 1:
+##     return 5;
+##   case 2:
+##     return 7;
+##   case 3:
+##     return 11;
+##   case 4:
+##     return 13;
+##   case 5:
+##     return 17;
+##   }
+##   return 0;
+## }
+## int main(int argc) {
+##   check_fptr(main, 1);
+##   return foo(argc);
+## }
+## const fptr funcs[2] = {foo, main};
+## b.c:
+## typedef int (*fptr)(int);
+## extern const fptr funcs[2];
+## #define assert(C) { if (!(C)) (*(unsigned long long *)0) = 0; }
+## void check_fptr(fptr f, int i) {
+##   assert(f == funcs[i]);
+## }
+	.text
+	.globl	foo
+	.type	foo, @function
+	.cfi_startproc
+	pushq	%rbx
+	.cfi_def_cfa_offset 16
+	.cfi_offset 3, -16
+	movl	%edi, %ebx
+	movl	$0, %esi
+	movl	$foo, %edi
+	call	check_fptr
+	movl	%ebx, %ebx
+	jmp	*.L4(,%rbx,8)
+	movl	$5, %eax
+	jmp	.L1
+	movl	$7, %eax
+	jmp	.L1
+	movl	$11, %eax
+	jmp	.L1
+	movl	$13, %eax
+	jmp	.L1
+	movl	$17, %eax
+	jmp	.L1
+	movl	$3, %eax
+	popq	%rbx
+	.cfi_def_cfa_offset 8
+	ret
+	.cfi_endproc
+	.size	foo, .-foo
+	.globl	_start
+	.type	_start, @function
+	.cfi_startproc
+	pushq	%rbx
+	.cfi_def_cfa_offset 16
+	.cfi_offset 3, -16
+	movl	%edi, %ebx
+	movl	$1, %esi
+	movl	$_start, %edi
+	call	check_fptr
+	movl	$1, %edi
+	call	foo
+	popq	%rbx
+	.cfi_def_cfa_offset 8
+  callq exit@PLT
+	.cfi_endproc
+	.size	_start, .-_start
+	.globl	check_fptr
+	.type	check_fptr, @function
+	.cfi_startproc
+	pushq	%rbp
+	.cfi_def_cfa_offset 16
+	.cfi_offset 6, -16
+	movq	%rsp, %rbp
+	.cfi_def_cfa_register 6
+	movq	%rdi, -8(%rbp)
+	movl	%esi, -12(%rbp)
+	movl	-12(%rbp), %eax
+	cltq
+	movq	funcs(,%rax,8), %rax
+	cmpq	%rax, -8(%rbp)
+	je	.L33
+	movl	$0, %eax
+	movq	$0, (%rax)
+	nop
+	popq	%rbp
+	.cfi_def_cfa 7, 8
+	ret
+	.cfi_endproc
+	.section	.rodata
+	.align 8
+	.align 4
+	.quad	.L10
+	.quad	.L8
+	.quad	.L7
+	.quad	.L6
+	.quad	.L5
+	.quad	.L3
+	.globl	funcs
+	.type	funcs, @object
+	.size	funcs, 16
+	.quad	foo
+	.quad	_start

