[llvm] [RISC-V] Adjust trampoline code for branch control flow protection (PR #141949)

Jesse Huang via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 3 03:21:46 PDT 2025


https://github.com/jaidTw updated https://github.com/llvm/llvm-project/pull/141949

>From 82ae8ef917258e0cbd40ce4ce11b2a6176922ed0 Mon Sep 17 00:00:00 2001
From: Jesse Huang <jesse.huang at sifive.com>
Date: Thu, 29 May 2025 06:25:50 -0700
Subject: [PATCH 1/4] [RISC-V] Adjust trampoline code for branch control flow
 protection

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 117 +++++++++++++-----
 .../test/CodeGen/RISCV/rv64-trampoline-cfi.ll |  95 ++++++++++++++
 2 files changed, 184 insertions(+), 28 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0a849f49116ee..2bcbe18e9beed 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -29,6 +29,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
@@ -8295,9 +8296,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   //     16: <StaticChainOffset>
   //     24: <FunctionAddressOffset>
   //     32:
-
-  constexpr unsigned StaticChainOffset = 16;
-  constexpr unsigned FunctionAddressOffset = 24;
+  // Offset with branch control flow protection enabled:
+  //      0: lpad    <imm20>
+  //      4: auipc   t3, 0
+  //      8: ld      t0, 28(t3)
+  //     12: ld      t3, 20(t3)
+  //     16: lui     t2, <imm20>
+  //     20: jalr    t0
+  //     24: <StaticChainOffset>
+  //     32: <FunctionAddressOffset>
+  //     40:
+
+  const bool HasCFBranch =
+      Subtarget.hasStdExtZicfilp() &&
+      DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
+  const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
+  const unsigned StaticChainOffset = StaticChainIdx * 4;
+  const unsigned FunctionAddressOffset = StaticChainOffset + 8;
 
   const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
   assert(STI);
@@ -8310,35 +8325,77 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   };
 
   SDValue OutChains[6];
-
-  uint32_t Encodings[] = {
-      // auipc t2, 0
-      // Loads the current PC into t2.
-      GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
-      // ld t0, 24(t2)
-      // Loads the function address into t0. Note that we are using offsets
-      // pc-relative to the first instruction of the trampoline.
-      GetEncoding(
-          MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
-              FunctionAddressOffset)),
-      // ld t2, 16(t2)
-      // Load the value of the static chain.
-      GetEncoding(
-          MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
-              StaticChainOffset)),
-      // jalr t0
-      // Jump to the function.
-      GetEncoding(MCInstBuilder(RISCV::JALR)
-                      .addReg(RISCV::X0)
-                      .addReg(RISCV::X5)
-                      .addImm(0))};
+  SDValue OutChainsLPAD[8];
+  if (HasCFBranch)
+    assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
+  else
+    assert(std::size(OutChains) == StaticChainIdx + 2);
+
+  SmallVector<uint32_t> Encodings;
+  if (!HasCFBranch) {
+    Encodings.append(
+        {// auipc t2, 0
+         // Loads the current PC into t2.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
+         // ld t0, 24(t2)
+         // Loads the function address into t0. Note that we are using offsets
+         // pc-relative to the first instruction of the trampoline.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X7)
+                         .addImm(FunctionAddressOffset)),
+         // ld t2, 16(t2)
+         // Load the value of the static chain.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X7)
+                         .addReg(RISCV::X7)
+                         .addImm(StaticChainOffset)),
+         // jalr t0
+         // Jump to the function.
+         GetEncoding(MCInstBuilder(RISCV::JALR)
+                         .addReg(RISCV::X0)
+                         .addReg(RISCV::X5)
+                         .addImm(0))});
+  } else {
+    Encodings.append(
+        {// auipc x0, <imm20> (lpad <imm20>)
+         // Landing pad.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
+         // auipc t3, 0
+         // Loads the current PC into t3.
+         GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
+         // ld t0, (FunctionAddressOffset - 4)(t3)
+         // Loads the function address into t0. Note that we are using offsets
+         // pc-relative to the SECOND instruction of the trampoline.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X28)
+                         .addImm(FunctionAddressOffset - 4)),
+         // ld t3, (StaticChainOffset - 4)(t3)
+         // Load the value of the static chain.
+         GetEncoding(MCInstBuilder(RISCV::LD)
+                         .addReg(RISCV::X28)
+                         .addReg(RISCV::X28)
+                         .addImm(StaticChainOffset - 4)),
+         // lui t2, <imm20>
+         // Setup the landing pad value.
+         GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
+         // jalr t0
+         // Jump to the function.
+         GetEncoding(MCInstBuilder(RISCV::JALR)
+                         .addReg(RISCV::X0)
+                         .addReg(RISCV::X5)
+                         .addImm(0))});
+  }
+
+  SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
 
   // Store encoded instructions.
   for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
     SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                                          DAG.getConstant(Idx * 4, dl, MVT::i64))
                            : Trmp;
-    OutChains[Idx] = DAG.getTruncStore(
+    OutChainsUsed[Idx] = DAG.getTruncStore(
         Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
         MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
   }
@@ -8361,12 +8418,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
         DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                     DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
     OffsetValue.Addr = Addr;
-    OutChains[Idx + 4] =
+    OutChainsUsed[Idx + StaticChainIdx] =
         DAG.getStore(Root, dl, OffsetValue.Value, Addr,
                      MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
   }
 
-  SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+  SDValue StoreToken;
+  if (HasCFBranch)
+    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
+  else
+    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
 
   // The end of instructions of trampoline is the same as the static chain
   // address that we computed earlier.
diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
new file mode 100644
index 0000000000000..304018ca0db56
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64 %s
+; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64-LINUX %s
+
+declare void @llvm.init.trampoline(ptr, ptr, ptr)
+declare ptr @llvm.adjust.trampoline(ptr)
+declare i64 @f(ptr nest, i64)
+
+define i64 @test0(i64 %n, ptr %p) nounwind {
+; RV64-LABEL: test0:
+; RV64:       # %bb.0:
+; RV64-NEXT:    lpad 0
+; RV64-NEXT:    addi sp, sp, -64
+; RV64-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-NEXT:    lui a0, %hi(f)
+; RV64-NEXT:    addi a0, a0, %lo(f)
+; RV64-NEXT:    sd a0, 48(sp)
+; RV64-NEXT:    sd a1, 40(sp)
+; RV64-NEXT:    li a0, 951
+; RV64-NEXT:    sw a0, 32(sp)
+; RV64-NEXT:    li a0, 23
+; RV64-NEXT:    sw a0, 16(sp)
+; RV64-NEXT:    lui a0, 40
+; RV64-NEXT:    addiw a0, a0, 103
+; RV64-NEXT:    sw a0, 36(sp)
+; RV64-NEXT:    lui a0, 5348
+; RV64-NEXT:    addiw a0, a0, -509
+; RV64-NEXT:    sw a0, 28(sp)
+; RV64-NEXT:    lui a0, 7395
+; RV64-NEXT:    addiw a0, a0, 643
+; RV64-NEXT:    sw a0, 24(sp)
+; RV64-NEXT:    lui a0, 1
+; RV64-NEXT:    addiw a0, a0, -489
+; RV64-NEXT:    sw a0, 20(sp)
+; RV64-NEXT:    addi a1, sp, 40
+; RV64-NEXT:    addi a0, sp, 16
+; RV64-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    call __clear_cache
+; RV64-NEXT:    ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT:    jalr a1
+; RV64-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 64
+; RV64-NEXT:    ret
+;
+; RV64-LINUX-LABEL: test0:
+; RV64-LINUX:       # %bb.0:
+; RV64-LINUX-NEXT:    lpad 0
+; RV64-LINUX-NEXT:    addi sp, sp, -64
+; RV64-LINUX-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    lui a0, %hi(f)
+; RV64-LINUX-NEXT:    addi a0, a0, %lo(f)
+; RV64-LINUX-NEXT:    sd a0, 48(sp)
+; RV64-LINUX-NEXT:    sd a1, 40(sp)
+; RV64-LINUX-NEXT:    li a0, 951
+; RV64-LINUX-NEXT:    sw a0, 32(sp)
+; RV64-LINUX-NEXT:    li a0, 23
+; RV64-LINUX-NEXT:    sw a0, 16(sp)
+; RV64-LINUX-NEXT:    lui a0, 40
+; RV64-LINUX-NEXT:    addiw a0, a0, 103
+; RV64-LINUX-NEXT:    sw a0, 36(sp)
+; RV64-LINUX-NEXT:    lui a0, 5348
+; RV64-LINUX-NEXT:    addiw a0, a0, -509
+; RV64-LINUX-NEXT:    sw a0, 28(sp)
+; RV64-LINUX-NEXT:    lui a0, 7395
+; RV64-LINUX-NEXT:    addiw a0, a0, 643
+; RV64-LINUX-NEXT:    sw a0, 24(sp)
+; RV64-LINUX-NEXT:    lui a0, 1
+; RV64-LINUX-NEXT:    addiw a0, a0, -489
+; RV64-LINUX-NEXT:    sw a0, 20(sp)
+; RV64-LINUX-NEXT:    addi a1, sp, 40
+; RV64-LINUX-NEXT:    addi a0, sp, 16
+; RV64-LINUX-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    li a2, 0
+; RV64-LINUX-NEXT:    call __riscv_flush_icache
+; RV64-LINUX-NEXT:    ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    jalr a1
+; RV64-LINUX-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    addi sp, sp, 64
+; RV64-LINUX-NEXT:    ret
+  %alloca = alloca [40 x i8], align 8
+  call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
+  %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
+  %ret = call i64 %tramp(i64 %n)
+  ret i64 %ret
+}
+
+!llvm.module.flags = !{!0}
+
+!0 = !{i32 8, !"cf-protection-branch", i32 1}

>From a286cf6a3ca6e1a8aaec06a707483c6b27ec2832 Mon Sep 17 00:00:00 2001
From: Jesse Huang <jesse.huang at sifive.com>
Date: Mon, 2 Jun 2025 04:20:16 -0700
Subject: [PATCH 2/4] [RISCV] Use software-guarded jump in the trampoline code

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 30 ++++++------
 .../test/CodeGen/RISCV/rv64-trampoline-cfi.ll | 46 ++++++++++---------
 2 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2bcbe18e9beed..1a84e20749410 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8299,18 +8299,17 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   // Offset with branch control flow protection enabled:
   //      0: lpad    <imm20>
   //      4: auipc   t3, 0
-  //      8: ld      t0, 28(t3)
-  //     12: ld      t3, 20(t3)
+  //      8: ld      t2, 24(t3)
+  //     12: ld      t3, 16(t3)
-  //     16: lui     t2, <imm20>
-  //     20: jalr    t0
-  //     24: <StaticChainOffset>
-  //     32: <FunctionAddressOffset>
-  //     40:
+  //     16: jalr    t2
+  //     20: <StaticChainOffset>
+  //     28: <FunctionAddressOffset>
+  //     36:
 
   const bool HasCFBranch =
       Subtarget.hasStdExtZicfilp() &&
       DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
-  const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
+  const unsigned StaticChainIdx = HasCFBranch ? 5 : 4;
   const unsigned StaticChainOffset = StaticChainIdx * 4;
   const unsigned FunctionAddressOffset = StaticChainOffset + 8;
 
@@ -8325,7 +8324,7 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
   };
 
   SDValue OutChains[6];
-  SDValue OutChainsLPAD[8];
+  SDValue OutChainsLPAD[7];
   if (HasCFBranch)
     assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
   else
@@ -8364,11 +8363,11 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
          // auipc t3, 0
          // Loads the current PC into t3.
          GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
-         // ld t0, (FunctionAddressOffset - 4)(t3)
-         // Loads the function address into t0. Note that we are using offsets
+         // ld t2, (FunctionAddressOffset - 4)(t3)
+         // Loads the function address into t2. Note that we are using offsets
          // pc-relative to the SECOND instruction of the trampoline.
          GetEncoding(MCInstBuilder(RISCV::LD)
-                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X7)
                          .addReg(RISCV::X28)
                          .addImm(FunctionAddressOffset - 4)),
          // ld t3, (StaticChainOffset - 4)(t3)
@@ -8377,14 +8376,11 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
                          .addReg(RISCV::X28)
                          .addReg(RISCV::X28)
                          .addImm(StaticChainOffset - 4)),
-         // lui t2, <imm20>
-         // Setup the landing pad value.
-         GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
-         // jalr t0
-         // Jump to the function.
+         // jalr t2
+         // Software-guarded jump to the function.
          GetEncoding(MCInstBuilder(RISCV::JALR)
                          .addReg(RISCV::X0)
-                         .addReg(RISCV::X5)
+                         .addReg(RISCV::X7)
                          .addImm(0))});
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
index 304018ca0db56..d328b3c5d5f02 100644
--- a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
@@ -17,25 +17,27 @@ define i64 @test0(i64 %n, ptr %p) nounwind {
 ; RV64-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
 ; RV64-NEXT:    lui a0, %hi(f)
 ; RV64-NEXT:    addi a0, a0, %lo(f)
-; RV64-NEXT:    sd a0, 48(sp)
-; RV64-NEXT:    sd a1, 40(sp)
-; RV64-NEXT:    li a0, 951
-; RV64-NEXT:    sw a0, 32(sp)
+; RV64-NEXT:    sw a0, 44(sp)
+; RV64-NEXT:    srli a0, a0, 32
+; RV64-NEXT:    sw a0, 48(sp)
+; RV64-NEXT:    sw a1, 36(sp)
+; RV64-NEXT:    srli a0, a1, 32
+; RV64-NEXT:    sw a0, 40(sp)
 ; RV64-NEXT:    li a0, 23
 ; RV64-NEXT:    sw a0, 16(sp)
-; RV64-NEXT:    lui a0, 40
+; RV64-NEXT:    lui a0, 56
 ; RV64-NEXT:    addiw a0, a0, 103
-; RV64-NEXT:    sw a0, 36(sp)
-; RV64-NEXT:    lui a0, 5348
+; RV64-NEXT:    sw a0, 32(sp)
+; RV64-NEXT:    lui a0, 4324
 ; RV64-NEXT:    addiw a0, a0, -509
 ; RV64-NEXT:    sw a0, 28(sp)
-; RV64-NEXT:    lui a0, 7395
-; RV64-NEXT:    addiw a0, a0, 643
+; RV64-NEXT:    lui a0, 6371
+; RV64-NEXT:    addiw a0, a0, 899
 ; RV64-NEXT:    sw a0, 24(sp)
 ; RV64-NEXT:    lui a0, 1
 ; RV64-NEXT:    addiw a0, a0, -489
 ; RV64-NEXT:    sw a0, 20(sp)
-; RV64-NEXT:    addi a1, sp, 40
+; RV64-NEXT:    addi a1, sp, 36
 ; RV64-NEXT:    addi a0, sp, 16
 ; RV64-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
 ; RV64-NEXT:    call __clear_cache
@@ -54,25 +56,27 @@ define i64 @test0(i64 %n, ptr %p) nounwind {
 ; RV64-LINUX-NEXT:    sd a0, 0(sp) # 8-byte Folded Spill
 ; RV64-LINUX-NEXT:    lui a0, %hi(f)
 ; RV64-LINUX-NEXT:    addi a0, a0, %lo(f)
-; RV64-LINUX-NEXT:    sd a0, 48(sp)
-; RV64-LINUX-NEXT:    sd a1, 40(sp)
-; RV64-LINUX-NEXT:    li a0, 951
-; RV64-LINUX-NEXT:    sw a0, 32(sp)
+; RV64-LINUX-NEXT:    sw a0, 44(sp)
+; RV64-LINUX-NEXT:    srli a0, a0, 32
+; RV64-LINUX-NEXT:    sw a0, 48(sp)
+; RV64-LINUX-NEXT:    sw a1, 36(sp)
+; RV64-LINUX-NEXT:    srli a0, a1, 32
+; RV64-LINUX-NEXT:    sw a0, 40(sp)
 ; RV64-LINUX-NEXT:    li a0, 23
 ; RV64-LINUX-NEXT:    sw a0, 16(sp)
-; RV64-LINUX-NEXT:    lui a0, 40
+; RV64-LINUX-NEXT:    lui a0, 56
 ; RV64-LINUX-NEXT:    addiw a0, a0, 103
-; RV64-LINUX-NEXT:    sw a0, 36(sp)
-; RV64-LINUX-NEXT:    lui a0, 5348
+; RV64-LINUX-NEXT:    sw a0, 32(sp)
+; RV64-LINUX-NEXT:    lui a0, 4324
 ; RV64-LINUX-NEXT:    addiw a0, a0, -509
 ; RV64-LINUX-NEXT:    sw a0, 28(sp)
-; RV64-LINUX-NEXT:    lui a0, 7395
-; RV64-LINUX-NEXT:    addiw a0, a0, 643
+; RV64-LINUX-NEXT:    lui a0, 6371
+; RV64-LINUX-NEXT:    addiw a0, a0, 899
 ; RV64-LINUX-NEXT:    sw a0, 24(sp)
 ; RV64-LINUX-NEXT:    lui a0, 1
 ; RV64-LINUX-NEXT:    addiw a0, a0, -489
 ; RV64-LINUX-NEXT:    sw a0, 20(sp)
-; RV64-LINUX-NEXT:    addi a1, sp, 40
+; RV64-LINUX-NEXT:    addi a1, sp, 36
 ; RV64-LINUX-NEXT:    addi a0, sp, 16
 ; RV64-LINUX-NEXT:    sd a0, 8(sp) # 8-byte Folded Spill
 ; RV64-LINUX-NEXT:    li a2, 0
@@ -83,7 +87,7 @@ define i64 @test0(i64 %n, ptr %p) nounwind {
 ; RV64-LINUX-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
 ; RV64-LINUX-NEXT:    addi sp, sp, 64
 ; RV64-LINUX-NEXT:    ret
-  %alloca = alloca [40 x i8], align 8
+  %alloca = alloca [36 x i8], align 8
   call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
   %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
   %ret = call i64 %tramp(i64 %n)

>From 0e0af2b21a1aecfa0777ed9c749aba861c57dd7e Mon Sep 17 00:00:00 2001
From: Jesse Huang <jesse.huang at sifive.com>
Date: Tue, 3 Jun 2025 02:51:37 -0700
Subject: [PATCH 3/4] Use SmallVector to handle both trampoline sequences

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 26 ++++++---------------
 1 file changed, 7 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1a84e20749410..60d145a9cc5b5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8323,12 +8323,7 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
     return Encoding;
   };
 
-  SDValue OutChains[6];
-  SDValue OutChainsLPAD[7];
-  if (HasCFBranch)
-    assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
-  else
-    assert(std::size(OutChains) == StaticChainIdx + 2);
+  SmallVector<SDValue> OutChains;
 
   SmallVector<uint32_t> Encodings;
   if (!HasCFBranch) {
@@ -8384,16 +8379,14 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
                          .addImm(0))});
   }
 
-  SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
-
   // Store encoded instructions.
   for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
     SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                                          DAG.getConstant(Idx * 4, dl, MVT::i64))
                            : Trmp;
-    OutChainsUsed[Idx] = DAG.getTruncStore(
+    OutChains.push_back(DAG.getTruncStore(
         Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
-        MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
+        MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32));
   }
 
   // Now store the variable part of the trampoline.
@@ -8409,21 +8402,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
       {StaticChainOffset, StaticChain},
       {FunctionAddressOffset, FunctionAddress},
   };
-  for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
+  for (auto &OffsetValue : OffsetValues) {
     SDValue Addr =
         DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                     DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
     OffsetValue.Addr = Addr;
-    OutChainsUsed[Idx + StaticChainIdx] =
-        DAG.getStore(Root, dl, OffsetValue.Value, Addr,
-                     MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
+    OutChains.push_back(DAG.getStore(Root, dl, OffsetValue.Value, Addr,
+                     MachinePointerInfo(TrmpAddr, OffsetValue.Offset)));
   }
 
-  SDValue StoreToken;
-  if (HasCFBranch)
-    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
-  else
-    StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+  SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
 
   // The end of instructions of trampoline is the same as the static chain
   // address that we computed earlier.

>From e7360dfc563694dc48a9da419447add36d4af43a Mon Sep 17 00:00:00 2001
From: Jesse Huang <jesse.huang at sifive.com>
Date: Tue, 3 Jun 2025 03:21:32 -0700
Subject: [PATCH 4/4] fixup! Add assertion to ensure the size of OutChains

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 60d145a9cc5b5..99e005143fae9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8411,6 +8411,7 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
                      MachinePointerInfo(TrmpAddr, OffsetValue.Offset)));
   }
 
+  assert(OutChains.size() == StaticChainIdx + 2 && "Size of OutChains mismatch");
   SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
 
   // The end of instructions of trampoline is the same as the static chain



More information about the llvm-commits mailing list