[llvm] [RISC-V] Adjust trampoline code for branch control flow protection (PR #141949)
Jesse Huang via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 06:41:59 PDT 2025
https://github.com/jaidTw created https://github.com/llvm/llvm-project/pull/141949
It is tricky to observe the trampoline code from the lit test file, because the instructions are encoded as integer constants and written onto the stack at runtime.
The stack of the test is organized as follows:
```
sp+56 $ra
sp+48 $a0 f
sp+40 $a1 p
sp+36 00028067 jalr t0
sp+32 000003b7 lui t2, 0
sp+28 014e3e03 ld t3, 20(t3)
sp+24 01ce3283 ld t0, 28(t3)
sp+20 00000e17 auipc t3, 0
sp+16 00000017 lpad 0
From 82ae8ef917258e0cbd40ce4ce11b2a6176922ed0 Mon Sep 17 00:00:00 2001
From: Jesse Huang <jesse.huang at sifive.com>
Date: Thu, 29 May 2025 06:25:50 -0700
Subject: [PATCH] [RISC-V] Adjust trampoline code for branch control flow
protection
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 117 +++++++++++++-----
.../test/CodeGen/RISCV/rv64-trampoline-cfi.ll | 95 ++++++++++++++
2 files changed, 184 insertions(+), 28 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0a849f49116ee..2bcbe18e9beed 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -29,6 +29,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
@@ -8295,9 +8296,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
// 16: <StaticChainOffset>
// 24: <FunctionAddressOffset>
// 32:
-
- constexpr unsigned StaticChainOffset = 16;
- constexpr unsigned FunctionAddressOffset = 24;
+ // Offset with branch control flow protection enabled:
+ // 0: lpad <imm20>
+ // 4: auipc t3, 0
+ // 8: ld t0, 28(t3)
+ // 12: ld t3, 20(t3)
+ // 16: lui t2, <imm20>
+ // 20: jalr t0
+ // 24: <StaticChainOffset>
+ // 32: <FunctionAddressOffset>
+ // 40:
+
+ const bool HasCFBranch =
+ Subtarget.hasStdExtZicfilp() &&
+ DAG.getMMI()->getModule()->getModuleFlag("cf-protection-branch");
+ const unsigned StaticChainIdx = HasCFBranch ? 6 : 4;
+ const unsigned StaticChainOffset = StaticChainIdx * 4;
+ const unsigned FunctionAddressOffset = StaticChainOffset + 8;
const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
assert(STI);
@@ -8310,35 +8325,77 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
};
SDValue OutChains[6];
-
- uint32_t Encodings[] = {
- // auipc t2, 0
- // Loads the current PC into t2.
- GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
- // ld t0, 24(t2)
- // Loads the function address into t0. Note that we are using offsets
- // pc-relative to the first instruction of the trampoline.
- GetEncoding(
- MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
- FunctionAddressOffset)),
- // ld t2, 16(t2)
- // Load the value of the static chain.
- GetEncoding(
- MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
- StaticChainOffset)),
- // jalr t0
- // Jump to the function.
- GetEncoding(MCInstBuilder(RISCV::JALR)
- .addReg(RISCV::X0)
- .addReg(RISCV::X5)
- .addImm(0))};
+ SDValue OutChainsLPAD[8];
+ if (HasCFBranch)
+ assert(std::size(OutChainsLPAD) == StaticChainIdx + 2);
+ else
+ assert(std::size(OutChains) == StaticChainIdx + 2);
+
+ SmallVector<uint32_t> Encodings;
+ if (!HasCFBranch) {
+ Encodings.append(
+ {// auipc t2, 0
+ // Loads the current PC into t2.
+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
+ // ld t0, 24(t2)
+ // Loads the function address into t0. Note that we are using offsets
+ // pc-relative to the first instruction of the trampoline.
+ GetEncoding(MCInstBuilder(RISCV::LD)
+ .addReg(RISCV::X5)
+ .addReg(RISCV::X7)
+ .addImm(FunctionAddressOffset)),
+ // ld t2, 16(t2)
+ // Load the value of the static chain.
+ GetEncoding(MCInstBuilder(RISCV::LD)
+ .addReg(RISCV::X7)
+ .addReg(RISCV::X7)
+ .addImm(StaticChainOffset)),
+ // jalr t0
+ // Jump to the function.
+ GetEncoding(MCInstBuilder(RISCV::JALR)
+ .addReg(RISCV::X0)
+ .addReg(RISCV::X5)
+ .addImm(0))});
+ } else {
+ Encodings.append(
+ {// auipc x0, <imm20> (lpad <imm20>)
+ // Landing pad.
+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
+ // auipc t3, 0
+ // Loads the current PC into t3.
+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
+ // ld t0, (FunctionAddressOffset - 4)(t3)
+ // Loads the function address into t0. Note that we are using offsets
+ // pc-relative to the SECOND instruction of the trampoline.
+ GetEncoding(MCInstBuilder(RISCV::LD)
+ .addReg(RISCV::X5)
+ .addReg(RISCV::X28)
+ .addImm(FunctionAddressOffset - 4)),
+ // ld t3, (StaticChainOffset - 4)(t3)
+ // Load the value of the static chain.
+ GetEncoding(MCInstBuilder(RISCV::LD)
+ .addReg(RISCV::X28)
+ .addReg(RISCV::X28)
+ .addImm(StaticChainOffset - 4)),
+ // lui t2, <imm20>
+ // Setup the landing pad value.
+ GetEncoding(MCInstBuilder(RISCV::LUI).addReg(RISCV::X7).addImm(0)),
+ // jalr t0
+ // Jump to the function.
+ GetEncoding(MCInstBuilder(RISCV::JALR)
+ .addReg(RISCV::X0)
+ .addReg(RISCV::X5)
+ .addImm(0))});
+ }
+
+ SDValue *OutChainsUsed = HasCFBranch ? OutChainsLPAD : OutChains;
// Store encoded instructions.
for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
DAG.getConstant(Idx * 4, dl, MVT::i64))
: Trmp;
- OutChains[Idx] = DAG.getTruncStore(
+ OutChainsUsed[Idx] = DAG.getTruncStore(
Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
}
@@ -8361,12 +8418,16 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
OffsetValue.Addr = Addr;
- OutChains[Idx + 4] =
+ OutChainsUsed[Idx + StaticChainIdx] =
DAG.getStore(Root, dl, OffsetValue.Value, Addr,
MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
}
- SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+ SDValue StoreToken;
+ if (HasCFBranch)
+ StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChainsLPAD);
+ else
+ StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
// The end of instructions of trampoline is the same as the static chain
// address that we computed earlier.
diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
new file mode 100644
index 0000000000000..304018ca0db56
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv64-trampoline-cfi.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64 %s
+; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64-LINUX %s
+
+declare void @llvm.init.trampoline(ptr, ptr, ptr)
+declare ptr @llvm.adjust.trampoline(ptr)
+declare i64 @f(ptr nest, i64)
+
+define i64 @test0(i64 %n, ptr %p) nounwind {
+; RV64-LABEL: test0:
+; RV64: # %bb.0:
+; RV64-NEXT: lpad 0
+; RV64-NEXT: addi sp, sp, -64
+; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-NEXT: lui a0, %hi(f)
+; RV64-NEXT: addi a0, a0, %lo(f)
+; RV64-NEXT: sd a0, 48(sp)
+; RV64-NEXT: sd a1, 40(sp)
+; RV64-NEXT: li a0, 951
+; RV64-NEXT: sw a0, 32(sp)
+; RV64-NEXT: li a0, 23
+; RV64-NEXT: sw a0, 16(sp)
+; RV64-NEXT: lui a0, 40
+; RV64-NEXT: addiw a0, a0, 103
+; RV64-NEXT: sw a0, 36(sp)
+; RV64-NEXT: lui a0, 5348
+; RV64-NEXT: addiw a0, a0, -509
+; RV64-NEXT: sw a0, 28(sp)
+; RV64-NEXT: lui a0, 7395
+; RV64-NEXT: addiw a0, a0, 643
+; RV64-NEXT: sw a0, 24(sp)
+; RV64-NEXT: lui a0, 1
+; RV64-NEXT: addiw a0, a0, -489
+; RV64-NEXT: sw a0, 20(sp)
+; RV64-NEXT: addi a1, sp, 40
+; RV64-NEXT: addi a0, sp, 16
+; RV64-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT: call __clear_cache
+; RV64-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-NEXT: jalr a1
+; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 64
+; RV64-NEXT: ret
+;
+; RV64-LINUX-LABEL: test0:
+; RV64-LINUX: # %bb.0:
+; RV64-LINUX-NEXT: lpad 0
+; RV64-LINUX-NEXT: addi sp, sp, -64
+; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: lui a0, %hi(f)
+; RV64-LINUX-NEXT: addi a0, a0, %lo(f)
+; RV64-LINUX-NEXT: sd a0, 48(sp)
+; RV64-LINUX-NEXT: sd a1, 40(sp)
+; RV64-LINUX-NEXT: li a0, 951
+; RV64-LINUX-NEXT: sw a0, 32(sp)
+; RV64-LINUX-NEXT: li a0, 23
+; RV64-LINUX-NEXT: sw a0, 16(sp)
+; RV64-LINUX-NEXT: lui a0, 40
+; RV64-LINUX-NEXT: addiw a0, a0, 103
+; RV64-LINUX-NEXT: sw a0, 36(sp)
+; RV64-LINUX-NEXT: lui a0, 5348
+; RV64-LINUX-NEXT: addiw a0, a0, -509
+; RV64-LINUX-NEXT: sw a0, 28(sp)
+; RV64-LINUX-NEXT: lui a0, 7395
+; RV64-LINUX-NEXT: addiw a0, a0, 643
+; RV64-LINUX-NEXT: sw a0, 24(sp)
+; RV64-LINUX-NEXT: lui a0, 1
+; RV64-LINUX-NEXT: addiw a0, a0, -489
+; RV64-LINUX-NEXT: sw a0, 20(sp)
+; RV64-LINUX-NEXT: addi a1, sp, 40
+; RV64-LINUX-NEXT: addi a0, sp, 16
+; RV64-LINUX-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: li a2, 0
+; RV64-LINUX-NEXT: call __riscv_flush_icache
+; RV64-LINUX-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: jalr a1
+; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: addi sp, sp, 64
+; RV64-LINUX-NEXT: ret
+ %alloca = alloca [40 x i8], align 8
+ call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
+ %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
+ %ret = call i64 %tramp(i64 %n)
+ ret i64 %ret
+}
+
+!llvm.module.flags = !{!0}
+
+!0 = !{i32 8, !"cf-protection-branch", i32 1}
More information about the llvm-commits
mailing list