[llvm] LowerTypeTests: Switch to emitting one inline asm call per jump table entry. (PR #136265)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 23:13:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Peter Collingbourne (pcc)

<details>
<summary>Changes</summary>

With the previous approach of emitting one inline asm call for all jump
table entries (taking one operand per entry), we would hit SelectionDAG's
limit on the number of operands per node (65536) once the number of jump
table entries exceeded that number. Fix the problem by switching to one
inline asm call per jump table entry, so that each DAG node only needs one
operand.


---
Full diff: https://github.com/llvm/llvm-project/pull/136265.diff


6 Files Affected:

- (modified) llvm/lib/Transforms/IPO/LowerTypeTests.cpp (+39-57) 
- (modified) llvm/test/Transforms/LowerTypeTests/aarch64-jumptable.ll (+2-1) 
- (modified) llvm/test/Transforms/LowerTypeTests/cfi-direct-call1.ll (+4-1) 
- (modified) llvm/test/Transforms/LowerTypeTests/function.ll (+22-14) 
- (modified) llvm/test/Transforms/LowerTypeTests/x86-jumptable.ll (+4-2) 
- (modified) llvm/test/Transforms/MergeFunc/cfi-thunk-merging.ll (+2-1) 


``````````diff
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index 7c01f8b560fea..653e7f0c50b6b 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -441,10 +441,6 @@ class LowerTypeTestsModule {
   // Cache variable used by hasBranchTargetEnforcement().
   int HasBranchTargetEnforcement = -1;
 
-  // The jump table type we ended up deciding on. (Usually the same as
-  // Arch, except that 'arm' and 'thumb' are often interchangeable.)
-  Triple::ArchType JumpTableArch = Triple::UnknownArch;
-
   IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
   IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
   PointerType *PtrTy = PointerType::getUnqual(M.getContext());
@@ -525,11 +521,8 @@ class LowerTypeTestsModule {
   Triple::ArchType
   selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions);
   bool hasBranchTargetEnforcement();
-  unsigned getJumpTableEntrySize();
-  Type *getJumpTableEntryType();
-  void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
-                            Triple::ArchType JumpTableArch,
-                            SmallVectorImpl<Value *> &AsmArgs, Function *Dest);
+  unsigned getJumpTableEntrySize(Triple::ArchType JumpTableArch);
+  InlineAsm *createJumpTableEntryAsm(Triple::ArchType JumpTableArch);
   void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
   void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                  ArrayRef<GlobalTypeMember *> Functions);
@@ -548,7 +541,8 @@ class LowerTypeTestsModule {
   void findGlobalVariableUsersOf(Constant *C,
                                  SmallSetVector<GlobalVariable *, 8> &Out);
 
-  void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);
+  void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions,
+                       Triple::ArchType JumpTableArch);
 
   /// replaceCfiUses - Go through the uses list for this definition
   /// and make each use point to "V" instead of "this" when the use is outside
@@ -1245,7 +1239,8 @@ bool LowerTypeTestsModule::hasBranchTargetEnforcement() {
   return HasBranchTargetEnforcement;
 }
 
-unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
+unsigned
+LowerTypeTestsModule::getJumpTableEntrySize(Triple::ArchType JumpTableArch) {
   switch (JumpTableArch) {
   case Triple::x86:
   case Triple::x86_64:
@@ -1278,33 +1273,32 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
   }
 }
 
-// Create a jump table entry for the target. This consists of an instruction
-// sequence containing a relative branch to Dest. Appends inline asm text,
-// constraints and arguments to AsmOS, ConstraintOS and AsmArgs.
-void LowerTypeTestsModule::createJumpTableEntry(
-    raw_ostream &AsmOS, raw_ostream &ConstraintOS,
-    Triple::ArchType JumpTableArch, SmallVectorImpl<Value *> &AsmArgs,
-    Function *Dest) {
-  unsigned ArgIndex = AsmArgs.size();
+// Create an inline asm constant representing a jump table entry for the target.
+// This consists of an instruction sequence containing a relative branch to
+// Dest.
+InlineAsm *
+LowerTypeTestsModule::createJumpTableEntryAsm(Triple::ArchType JumpTableArch) {
+  std::string Asm;
+  raw_string_ostream AsmOS(Asm);
 
   if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) {
     bool Endbr = false;
     if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
-          Dest->getParent()->getModuleFlag("cf-protection-branch")))
+            M.getModuleFlag("cf-protection-branch")))
       Endbr = !MD->isZero();
     if (Endbr)
       AsmOS << (JumpTableArch == Triple::x86 ? "endbr32\n" : "endbr64\n");
-    AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n";
+    AsmOS << "jmp ${0:c}@plt\n";
     if (Endbr)
       AsmOS << ".balign 16, 0xcc\n";
     else
       AsmOS << "int3\nint3\nint3\n";
   } else if (JumpTableArch == Triple::arm) {
-    AsmOS << "b $" << ArgIndex << "\n";
+    AsmOS << "b $0\n";
   } else if (JumpTableArch == Triple::aarch64) {
     if (hasBranchTargetEnforcement())
       AsmOS << "bti c\n";
-    AsmOS << "b $" << ArgIndex << "\n";
+    AsmOS << "b $0\n";
   } else if (JumpTableArch == Triple::thumb) {
     if (!CanUseThumbBWJumpTable) {
       // In Armv6-M, this sequence will generate a branch without corrupting
@@ -1328,28 +1322,26 @@ void LowerTypeTestsModule::createJumpTableEntry(
             << "str r0, [sp, #4]\n"
             << "pop {r0,pc}\n"
             << ".balign 4\n"
-            << "1: .word $" << ArgIndex << " - (0b + 4)\n";
+            << "1: .word $0 - (0b + 4)\n";
     } else {
       if (hasBranchTargetEnforcement())
         AsmOS << "bti\n";
-      AsmOS << "b.w $" << ArgIndex << "\n";
+      AsmOS << "b.w $0\n";
     }
   } else if (JumpTableArch == Triple::riscv32 ||
              JumpTableArch == Triple::riscv64) {
-    AsmOS << "tail $" << ArgIndex << "@plt\n";
+    AsmOS << "tail $0@plt\n";
   } else if (JumpTableArch == Triple::loongarch64) {
-    AsmOS << "pcalau12i $$t0, %pc_hi20($" << ArgIndex << ")\n"
-          << "jirl $$r0, $$t0, %pc_lo12($" << ArgIndex << ")\n";
+    AsmOS << "pcalau12i $$t0, %pc_hi20($0)\n"
+          << "jirl $$r0, $$t0, %pc_lo12($0)\n";
   } else {
     report_fatal_error("Unsupported architecture for jump tables");
   }
 
-  ConstraintOS << (ArgIndex > 0 ? ",s" : "s");
-  AsmArgs.push_back(Dest);
-}
-
-Type *LowerTypeTestsModule::getJumpTableEntryType() {
-  return ArrayType::get(Int8Ty, getJumpTableEntrySize());
+  return InlineAsm::get(
+      FunctionType::get(Type::getVoidTy(M.getContext()), PtrTy, false),
+      AsmOS.str(), "s",
+      /*hasSideEffects=*/true);
 }
 
 /// Given a disjoint set of type identifiers and functions, build the bit sets
@@ -1498,12 +1490,17 @@ Triple::ArchType LowerTypeTestsModule::selectJumpTableArmEncoding(
 }
 
 void LowerTypeTestsModule::createJumpTable(
-    Function *F, ArrayRef<GlobalTypeMember *> Functions) {
+    Function *F, ArrayRef<GlobalTypeMember *> Functions, Triple::ArchType JumpTableArch) {
   std::string AsmStr, ConstraintStr;
   raw_string_ostream AsmOS(AsmStr), ConstraintOS(ConstraintStr);
   SmallVector<Value *, 16> AsmArgs;
   AsmArgs.reserve(Functions.size() * 2);
 
+  BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
+  IRBuilder<> IRB(BB);
+
+  InlineAsm *JumpTableAsm = createJumpTableEntryAsm(JumpTableArch);
+
   // Check if all entries have the NoUnwind attribute.
   // If all entries have it, we can safely mark the
   // cfi.jumptable as NoUnwind, otherwise, direct calls
@@ -1514,12 +1511,12 @@ void LowerTypeTestsModule::createJumpTable(
              ->hasFnAttribute(llvm::Attribute::NoUnwind)) {
       areAllEntriesNounwind = false;
     }
-    createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs,
-                         cast<Function>(GTM->getGlobal()));
+    IRB.CreateCall(JumpTableAsm, GTM->getGlobal());
   }
+  IRB.CreateUnreachable();
 
   // Align the whole table by entry size.
-  F->setAlignment(Align(getJumpTableEntrySize()));
+  F->setAlignment(Align(getJumpTableEntrySize(JumpTableArch)));
   // Skip prologue.
   // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3.
   // Luckily, this function does not get any prologue even without the
@@ -1568,21 +1565,6 @@ void LowerTypeTestsModule::createJumpTable(
 
   // Make sure we do not inline any calls to the cfi.jumptable.
   F->addFnAttr(Attribute::NoInline);
-
-  BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
-  IRBuilder<> IRB(BB);
-
-  SmallVector<Type *, 16> ArgTypes;
-  ArgTypes.reserve(AsmArgs.size());
-  for (const auto &Arg : AsmArgs)
-    ArgTypes.push_back(Arg->getType());
-  InlineAsm *JumpTableAsm =
-      InlineAsm::get(FunctionType::get(IRB.getVoidTy(), ArgTypes, false),
-                     AsmOS.str(), ConstraintOS.str(),
-                     /*hasSideEffects=*/true);
-
-  IRB.CreateCall(JumpTableAsm, AsmArgs);
-  IRB.CreateUnreachable();
 }
 
 /// Given a disjoint set of type identifiers and functions, build a jump table
@@ -1669,11 +1651,11 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
 
   // Decide on the jump table encoding, so that we know how big the
   // entries will be.
-  JumpTableArch = selectJumpTableArmEncoding(Functions);
+  Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions);
 
   // Build a simple layout based on the regular layout of jump tables.
   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
-  unsigned EntrySize = getJumpTableEntrySize();
+  unsigned EntrySize = getJumpTableEntrySize(JumpTableArch);
   for (unsigned I = 0; I != Functions.size(); ++I)
     GlobalLayout[Functions[I]] = I * EntrySize;
 
@@ -1684,7 +1666,7 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
                        M.getDataLayout().getProgramAddressSpace(),
                        ".cfi.jumptable", &M);
   ArrayType *JumpTableType =
-      ArrayType::get(getJumpTableEntryType(), Functions.size());
+      ArrayType::get(ArrayType::get(Int8Ty, EntrySize), Functions.size());
   auto JumpTable = ConstantExpr::getPointerCast(
       JumpTableFn, PointerType::getUnqual(M.getContext()));
 
@@ -1742,7 +1724,7 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
     }
   }
 
-  createJumpTable(JumpTableFn, Functions);
+  createJumpTable(JumpTableFn, Functions, JumpTableArch);
 }
 
 /// Assign a dummy layout using an incrementing counter, tag each function
diff --git a/llvm/test/Transforms/LowerTypeTests/aarch64-jumptable.ll b/llvm/test/Transforms/LowerTypeTests/aarch64-jumptable.ll
index 3464a748778b6..5ac6d00d9afd1 100644
--- a/llvm/test/Transforms/LowerTypeTests/aarch64-jumptable.ll
+++ b/llvm/test/Transforms/LowerTypeTests/aarch64-jumptable.ll
@@ -53,6 +53,7 @@ define i1 @foo(ptr %p) {
 ; AARCH64-LABEL: define private void @.cfi.jumptable
 ; AARCH64-SAME: () #[[ATTR1:[0-9]+]] align 8 {
 ; AARCH64-NEXT:  entry:
-; AARCH64-NEXT:    call void asm sideeffect "bti c\0Ab $0\0Abti c\0Ab $1\0A", "s,s"(ptr @f.cfi, ptr @g.cfi)
+; AARCH64-NEXT:    call void asm sideeffect "bti c\0Ab $0\0A", "s"(ptr @f.cfi)
+; AARCH64-NEXT:    call void asm sideeffect "bti c\0Ab $0\0A", "s"(ptr @g.cfi)
 ; AARCH64-NEXT:    unreachable
 ;
diff --git a/llvm/test/Transforms/LowerTypeTests/cfi-direct-call1.ll b/llvm/test/Transforms/LowerTypeTests/cfi-direct-call1.ll
index 3afb4875ca288..16e8dcc1d6f4c 100644
--- a/llvm/test/Transforms/LowerTypeTests/cfi-direct-call1.ll
+++ b/llvm/test/Transforms/LowerTypeTests/cfi-direct-call1.ll
@@ -73,7 +73,10 @@ entry:
 ; Check which jump table entries are created
 ; FULL: define private void @.cfi.jumptable(){{.*}}
 ; FULL-NEXT: entry:
-; FULL-NEXT: call void asm{{.*}}local_func1.cfi{{.*}}local_func2.cfi{{.*}}extern_weak{{.*}}extern_decl
+; FULL-NEXT: call void asm{{.*}}local_func1.cfi
+; FULL-NEXT: call void asm{{.*}}local_func2.cfi
+; FULL-NEXT: call void asm{{.*}}extern_weak
+; FULL-NEXT: call void asm{{.*}}extern_decl
 
 ; Make sure all local functions have been renamed to <name>.cfi
 ; THIN: define hidden i32 @local_func1.cfi()
diff --git a/llvm/test/Transforms/LowerTypeTests/function.ll b/llvm/test/Transforms/LowerTypeTests/function.ll
index 1504d6a47847e..f80e99ebfba2c 100644
--- a/llvm/test/Transforms/LowerTypeTests/function.ll
+++ b/llvm/test/Transforms/LowerTypeTests/function.ll
@@ -73,16 +73,10 @@ define i1 @foo(ptr %p) {
 ; X86-SAME: int3
 ; X86-SAME: int3
 ; X86-SAME: int3
-; X86-SAME: jmp ${1:c}@plt
-; X86-SAME: int3
-; X86-SAME: int3
-; X86-SAME: int3
 
 ; ARM:      b $0
-; ARM-SAME: b $1
 
 ; THUMB:      b.w $0
-; THUMB-SAME: b.w $1
 
 ; THUMBV6M:      push {r0,r1}
 ; THUMBV6M-SAME: ldr r0, 1f
@@ -91,23 +85,37 @@ define i1 @foo(ptr %p) {
 ; THUMBV6M-SAME: pop {r0,pc}
 ; THUMBV6M-SAME: .balign 4
 ; THUMBV6M-SAME: 1: .word $0 - (0b + 4)
-; THUMBV6M-SAME: push {r0,r1}
+
+; RISCV:      tail $0@plt
+
+; LOONGARCH64:      pcalau12i $$t0, %pc_hi20($0)
+; LOONGARCH64-SAME: jirl $$r0, $$t0, %pc_lo12($0)
+
+; NATIVE-SAME: "s"(ptr @f.cfi)
+
+; X86-NEXT: jmp ${0:c}@plt
+; X86-SAME: int3
+; X86-SAME: int3
+; X86-SAME: int3
+
+; ARM-NEXT: b $0
+
+; THUMB-NEXT: b.w $0
+
+; THUMBV6M-NEXT: push {r0,r1}
 ; THUMBV6M-SAME: ldr r0, 1f
 ; THUMBV6M-SAME: 0: add r0, r0, pc
 ; THUMBV6M-SAME: str r0, [sp, #4]
 ; THUMBV6M-SAME: pop {r0,pc}
 ; THUMBV6M-SAME: .balign 4
-; THUMBV6M-SAME: 1: .word $1 - (0b + 4)
+; THUMBV6M-SAME: 1: .word $0 - (0b + 4)
 
-; RISCV:      tail $0@plt
-; RISCV-SAME: tail $1@plt
+; RISCV-NEXT: tail $0@plt
 
-; LOONGARCH64:      pcalau12i $$t0, %pc_hi20($0)
+; LOONGARCH64-NEXT: pcalau12i $$t0, %pc_hi20($0)
 ; LOONGARCH64-SAME: jirl $$r0, $$t0, %pc_lo12($0)
-; LOONGARCH64-SAME: pcalau12i $$t0, %pc_hi20($1)
-; LOONGARCH64-SAME: jirl $$r0, $$t0, %pc_lo12($1)
 
-; NATIVE-SAME: "s,s"(ptr @f.cfi, ptr @g.cfi)
+; NATIVE-SAME: "s"(ptr @g.cfi)
 
 ; X86-LINUX: attributes #[[ATTR]] = { naked nocf_check noinline }
 ; X86-WIN32: attributes #[[ATTR]] = { nocf_check noinline }
diff --git a/llvm/test/Transforms/LowerTypeTests/x86-jumptable.ll b/llvm/test/Transforms/LowerTypeTests/x86-jumptable.ll
index f56d30be37959..76acf9469785d 100644
--- a/llvm/test/Transforms/LowerTypeTests/x86-jumptable.ll
+++ b/llvm/test/Transforms/LowerTypeTests/x86-jumptable.ll
@@ -25,7 +25,9 @@ define i1 @foo(ptr %p) {
 
 ; X86:         define private void @.cfi.jumptable() #[[#ATTR:]] align 16 {
 ; X86-NEXT:    entry:
-; X86_32-NEXT:   call void asm sideeffect "endbr32\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0Aendbr32\0Ajmp ${1:c}@plt\0A.balign 16, 0xcc\0A", "s,s"(ptr @f.cfi, ptr @g.cfi)
-; X86_64-NEXT:   call void asm sideeffect "endbr64\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0Aendbr64\0Ajmp ${1:c}@plt\0A.balign 16, 0xcc\0A", "s,s"(ptr @f.cfi, ptr @g.cfi)
+; X86_32-NEXT:   call void asm sideeffect "endbr32\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0A", "s"(ptr @f.cfi)
+; X86_32-NEXT:   call void asm sideeffect "endbr32\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0A", "s"(ptr @g.cfi)
+; X86_64-NEXT:   call void asm sideeffect "endbr64\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0A", "s"(ptr @f.cfi)
+; X86_64-NEXT:   call void asm sideeffect "endbr64\0Ajmp ${0:c}@plt\0A.balign 16, 0xcc\0A", "s"(ptr @g.cfi)
 
 ; X86_64: attributes #[[#ATTR]] = { naked nocf_check noinline }
diff --git a/llvm/test/Transforms/MergeFunc/cfi-thunk-merging.ll b/llvm/test/Transforms/MergeFunc/cfi-thunk-merging.ll
index 562cc1a973d81..f4225f95538a0 100644
--- a/llvm/test/Transforms/MergeFunc/cfi-thunk-merging.ll
+++ b/llvm/test/Transforms/MergeFunc/cfi-thunk-merging.ll
@@ -205,6 +205,7 @@ attributes #3 = { noreturn nounwind }
 ; LOWERTYPETESTS-LABEL: define private void @.cfi.jumptable
 ; LOWERTYPETESTS-SAME: () #[[ATTR3:[0-9]+]] align 8 {
 ; LOWERTYPETESTS-NEXT:  entry:
-; LOWERTYPETESTS-NEXT:    call void asm sideeffect "jmp ${0:c}@plt\0Aint3\0Aint3\0Aint3\0Ajmp ${1:c}@plt\0Aint3\0Aint3\0Aint3\0A", "s,s"(ptr @f, ptr @f_thunk)
+; LOWERTYPETESTS-NEXT:    call void asm sideeffect "jmp ${0:c}@plt\0Aint3\0Aint3\0Aint3\0A", "s"(ptr @f)
+; LOWERTYPETESTS-NEXT:    call void asm sideeffect "jmp ${0:c}@plt\0Aint3\0Aint3\0Aint3\0A", "s"(ptr @f_thunk)
 ; LOWERTYPETESTS-NEXT:    unreachable
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/136265


More information about the llvm-commits mailing list