[llvm] [AArch64][PAC] Lower jump-tables using hardened pseudo. (PR #97666)
Ahmed Bougacha via llvm-commits
llvm-commits@lists.llvm.org
Wed Jul 3 20:17:51 PDT 2024
https://github.com/ahmedbougacha created https://github.com/llvm/llvm-project/pull/97666
This introduces an alternative, hardened lowering for jump-table dispatch, controlled by the "jump-table-hardening" function attribute. The implementation is centered on a pseudo-instruction; quoting its description:
> A hardened but more expensive version of jump-table dispatch.
> This combines the target address computation (otherwise done using the
> JumpTableDest pseudos above) with the branch itself (otherwise done using
> a plain BR) in a single non-attackable sequence.
>
> We take the final entry index as an operand to allow isel freedom. This does
> mean that the index can be attacker-controlled. To address that, we also do
> limited checking of the offset, mainly ensuring it still points within the
> jump-table array. When it doesn't, this branches to the first entry.
>
> This is intended for use in conjunction with ptrauth for other code pointers,
> to avoid signing jump-table entries and turning them into pointers.
>
> Entry index is passed in x16. Clobbers x16/x17/nzcv.
Jump-table compression isn't supported in this patch.
We can add it relatively easily in a separate change.
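
To make the "non-attackable" claim concrete, here is a rough C++ model of
what the expanded pseudo does at run time; the function and parameter names
are illustrative, not from the patch:

  #include <cstdint>

  // Sketch of the hardened dispatch. In the real lowering this is one
  // straight-line machine-code sequence over x16/x17, with no
  // register-allocated intermediate value an attacker could substitute.
  uintptr_t hardenedDispatchTarget(uint64_t Index,       // arrives in x16
                                   const int32_t *Table, // adrp/add of LJTIn_n
                                   uintptr_t Anchor,     // adr of Lanchor
                                   uint64_t MaxTableEntry) {
    // cmp x16, #MaxTableEntry ; csel x16, x16, xzr, ls
    if (Index > MaxTableEntry)
      Index = 0; // out-of-range indices are redirected to entry #0
    // ldrsw x16, [x17, x16, lsl #2]
    int64_t Offset = Table[Index];
    // adr x17, Lanchor ; add x16, x17, x16 ; br x16
    return Anchor + static_cast<uintptr_t>(Offset);
  }

Because the entries are anchor-relative 32-bit offsets rather than absolute
pointers, nothing in the table itself needs to be signed.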
From 6b7eb51ee224e3419103b0484570c985250faaf1 Mon Sep 17 00:00:00 2001
From: Ahmed Bougacha <ahmed@bougacha.org>
Date: Mon, 27 Sep 2021 08:00:00 -0700
Subject: [PATCH] [AArch64][PAC] Lower jump-tables using hardened pseudo.
---
llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp | 141 ++++++++++++++++++
.../Target/AArch64/AArch64ISelLowering.cpp | 15 ++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 26 ++++
.../GISel/AArch64InstructionSelector.cpp | 14 +-
.../CodeGen/AArch64/hardened-jump-table-br.ll | 53 +++++++
5 files changed, 248 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/AArch64/hardened-jump-table-br.ll
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 64d41d4147644..79d3f7e386fdf 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -104,6 +104,8 @@ class AArch64AsmPrinter : public AsmPrinter {
void LowerJumpTableDest(MCStreamer &OutStreamer, const MachineInstr &MI);
+ void LowerHardenedBRJumpTable(const MachineInstr &MI);
+
void LowerMOPS(MCStreamer &OutStreamer, const MachineInstr &MI);
void LowerSTACKMAP(MCStreamer &OutStreamer, StackMaps &SM,
@@ -1310,6 +1312,141 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
.addImm(Size == 4 ? 0 : 2));
}
+void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
+ unsigned InstsEmitted = 0;
+
+ const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
+ assert(MJTI && "Can't lower jump-table dispatch without JTI");
+
+ const std::vector<MachineJumpTableEntry> &JTs = MJTI->getJumpTables();
+ assert(!JTs.empty() && "Invalid JT index for jump-table dispatch");
+
+ // Emit:
+ // mov x17, #<size of table> ; depending on table size, with MOVKs
+ // cmp x16, x17 ; or #imm if table size fits in 12-bit
+ // csel x16, x16, xzr, ls ; check for index overflow
+ //
+  //     adrp x17, Ltable@PAGE         ; materialize table address
+  //     add x17, x17, Ltable@PAGEOFF
+ // ldrsw x16, [x17, x16, lsl #2] ; load table entry
+ //
+ // Lanchor:
+ // adr x17, Lanchor ; compute target address
+ // add x16, x17, x16
+ // br x16 ; branch to target
+
+ MachineOperand JTOp = MI.getOperand(0);
+
+ unsigned JTI = JTOp.getIndex();
+ assert(!AArch64FI->getJumpTableEntryPCRelSymbol(JTI) &&
+ "unsupported compressed jump table");
+
+ const uint64_t NumTableEntries = JTs[JTI].MBBs.size();
+
+ // cmp only supports a 12-bit immediate. If we need more, materialize the
+ // immediate, using x17 as a scratch register.
+ uint64_t MaxTableEntry = NumTableEntries - 1;
+ if (isUInt<12>(MaxTableEntry)) {
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXri)
+ .addReg(AArch64::XZR)
+ .addReg(AArch64::X16)
+ .addImm(MaxTableEntry)
+ .addImm(0));
+ ++InstsEmitted;
+ } else {
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(AArch64::MOVZXi)
+ .addReg(AArch64::X17)
+ .addImm(static_cast<uint16_t>(MaxTableEntry))
+ .addImm(0));
+ ++InstsEmitted;
+ // It's sad that we have to manually materialize instructions, but we can't
+ // trivially reuse the main pseudo expansion logic.
+ // A MOVK sequence is easy enough to generate and handles the general case.
+ for (int Offset = 16; Offset < 64; Offset += 16) {
+ if ((MaxTableEntry >> Offset) == 0)
+ break;
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(AArch64::MOVKXi)
+ .addReg(AArch64::X17)
+ .addReg(AArch64::X17)
+ .addImm(static_cast<uint16_t>(MaxTableEntry >> Offset))
+ .addImm(Offset));
+ ++InstsEmitted;
+ }
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
+ .addReg(AArch64::XZR)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X17)
+ .addImm(0));
+ ++InstsEmitted;
+ }
+
+ // This picks entry #0 on failure.
+ // We might want to trap instead.
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::CSELXr)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::XZR)
+ .addImm(AArch64CC::LS));
+ ++InstsEmitted;
+
+ // Prepare the @PAGE/@PAGEOFF low/high operands.
+ MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
+ MCOperand JTMCHi, JTMCLo;
+
+ JTMOHi.setTargetFlags(AArch64II::MO_PAGE);
+ JTMOLo.setTargetFlags(AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
+
+ MCInstLowering.lowerOperand(JTMOHi, JTMCHi);
+ MCInstLowering.lowerOperand(JTMOLo, JTMCLo);
+
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADRP)
+ .addReg(AArch64::X17)
+ .addOperand(JTMCHi));
+ ++InstsEmitted;
+
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
+ .addReg(AArch64::X17)
+ .addReg(AArch64::X17)
+ .addOperand(JTMCLo)
+ .addImm(0));
+ ++InstsEmitted;
+
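+  // The two trailing immediates pick an LSL register-offset extend, shifted
+  // by the access size: [x17, x16, lsl #2] loads the entry for index x16.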
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X17)
+ .addReg(AArch64::X16)
+ .addImm(0)
+ .addImm(1));
+ ++InstsEmitted;
+
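+  // Entries are emitted as 4-byte offsets from the anchor label created
+  // below, so the final target is rebuilt as Lanchor + entry.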
+ MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
+ auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
+ AArch64FI->setJumpTableEntryInfo(JTI, 4, AdrLabel);
+
+ OutStreamer->emitLabel(AdrLabel);
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADR)
+ .addReg(AArch64::X17)
+ .addExpr(AdrLabelE));
+ ++InstsEmitted;
+
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X17)
+ .addReg(AArch64::X16)
+ .addImm(0));
+ ++InstsEmitted;
+
+ EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR)
+ .addReg(AArch64::X16));
+ ++InstsEmitted;
+
+ assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
+}
+
void AArch64AsmPrinter::LowerMOPS(llvm::MCStreamer &OutStreamer,
const llvm::MachineInstr &MI) {
unsigned Opcode = MI.getOpcode();
@@ -2177,6 +2314,10 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
LowerJumpTableDest(*OutStreamer, *MI);
return;
+ case AArch64::BR_JumpTable:
+ LowerHardenedBRJumpTable(*MI);
+ return;
+
case AArch64::FMOVH0:
case AArch64::FMOVS0:
case AArch64::FMOVD0:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e0c3cc5eddb82..4248c7e8d6c60 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10678,6 +10678,21 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
+ // With jump-table-hardening, we only expand the full jump table dispatch
+ // sequence later, to guarantee the integrity of the intermediate values.
+ if (DAG.getMachineFunction().getFunction()
+ .hasFnAttribute("jump-table-hardening") ||
+ Subtarget->getTargetTriple().isArm64e()) {
+ assert(Subtarget->isTargetMachO() &&
+ "hardened jump-table not yet supported on non-macho");
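+    // The pseudo's expansion expects the entry index in x16 (and clobbers
+    // x16/x17/nzcv), so pin the index to the physical register here.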
+ SDValue X16Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::X16,
+ Entry, SDValue());
+ SDNode *B = DAG.getMachineNode(AArch64::BR_JumpTable, DL, MVT::Other,
+ DAG.getTargetJumpTable(JTI, MVT::i32),
+ X16Copy.getValue(0), X16Copy.getValue(1));
+ return SDValue(B, 0);
+ }
+
SDNode *Dest =
DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 1e06d5fdc7562..0c43851ac121b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1143,6 +1143,32 @@ def JumpTableDest8 : Pseudo<(outs GPR64:$dst, GPR64sp:$scratch),
Sched<[]>;
}
+// A hardened but more expensive version of jump-table dispatch.
+// This combines the target address computation (otherwise done using the
+// JumpTableDest pseudos above) with the branch itself (otherwise done using
+// a plain BR) in a single non-attackable sequence.
+//
+// We take the final entry index as an operand to allow isel freedom. This does
+// mean that the index can be attacker-controlled. To address that, we also do
+// limited checking of the offset, mainly ensuring it still points within the
+// jump-table array. When it doesn't, this branches to the first entry.
+//
+// This is intended for use in conjunction with ptrauth for other code pointers,
+// to avoid signing jump-table entries and turning them into pointers.
+//
+// Entry index is passed in x16. Clobbers x16/x17/nzcv.
+def BR_JumpTable : Pseudo<(outs), (ins i32imm:$jti), []>, Sched<[]> {
+ let isBranch = 1;
+ let isTerminator = 1;
+ let isIndirectBranch = 1;
+ let isBarrier = 1;
+ let isNotDuplicable = 1;
+ let Defs = [X16,X17,NZCV];
+ let Uses = [X16];
+ let Size = 44; // 28 fixed + 16 variable, for table size materialization
+}
+
// Space-consuming pseudo to aid testing of placement and reachability
// algorithms. Immediate operand is the number of bytes this "instruction"
// occupies; register operands can be used to enforce dependency and constrain
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 9e0860934f777..b6e8ffac0a6d2 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -3597,10 +3597,22 @@ bool AArch64InstructionSelector::selectBrJT(MachineInstr &I,
unsigned JTI = I.getOperand(1).getIndex();
Register Index = I.getOperand(2).getReg();
+ MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
+ if (MF->getFunction().hasFnAttribute("jump-table-hardening") ||
+ STI.getTargetTriple().isArm64e()) {
+ if (TM.getCodeModel() != CodeModel::Small)
+ report_fatal_error("Unsupported code-model for hardened jump-table");
+
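+    // As in the SDAG path, BR_JumpTable takes the entry index in x16.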
+    MIB.buildCopy({AArch64::X16}, Index);
+    MIB.buildInstr(AArch64::BR_JumpTable).addJumpTableIndex(JTI);
+ I.eraseFromParent();
+ return true;
+ }
+
Register TargetReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
Register ScratchReg = MRI.createVirtualRegister(&AArch64::GPR64spRegClass);
- MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr);
auto JumpTableInst = MIB.buildInstr(AArch64::JumpTableDest32,
{TargetReg, ScratchReg}, {JTAddr, Index})
.addJumpTableIndex(JTI);
diff --git a/llvm/test/CodeGen/AArch64/hardened-jump-table-br.ll b/llvm/test/CodeGen/AArch64/hardened-jump-table-br.ll
new file mode 100644
index 0000000000000..fa71b15b285aa
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/hardened-jump-table-br.ll
@@ -0,0 +1,53 @@
+; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 | FileCheck %s
+; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -code-model=large | FileCheck %s
+; RUN: llc -verify-machineinstrs -o - %s -mtriple=arm64-apple-ios -aarch64-min-jump-table-entries=1 -aarch64-enable-atomic-cfg-tidy=0 -global-isel -global-isel-abort=1 | FileCheck %s
+
+; CHECK-LABEL: test_jumptable:
+; CHECK: mov w[[INDEX:[0-9]+]], w0
+; CHECK: cmp x[[INDEX]], #5
+; CHECK: csel [[INDEX2:x[0-9]+]], x[[INDEX]], xzr, ls
+; CHECK-NEXT: adrp [[JTPAGE:x[0-9]+]], LJTI0_0@PAGE
+; CHECK-NEXT: add x[[JT:[0-9]+]], [[JTPAGE]], LJTI0_0@PAGEOFF
+; CHECK-NEXT: ldrsw [[OFFSET:x[0-9]+]], [x[[JT]], [[INDEX2]], lsl #2]
+; CHECK-NEXT: Ltmp0:
+; CHECK-NEXT: adr [[TABLE:x[0-9]+]], Ltmp0
+; CHECK-NEXT: add [[DEST:x[0-9]+]], [[TABLE]], [[OFFSET]]
+; CHECK-NEXT: br [[DEST]]
+
+define i32 @test_jumptable(i32 %in) "jump-table-hardening" {
+
+ switch i32 %in, label %def [
+ i32 0, label %lbl1
+ i32 1, label %lbl2
+ i32 2, label %lbl3
+ i32 4, label %lbl4
+ i32 5, label %lbl5
+ ]
+
+def:
+ ret i32 0
+
+lbl1:
+ ret i32 1
+
+lbl2:
+ ret i32 2
+
+lbl3:
+ ret i32 4
+
+lbl4:
+ ret i32 8
+
+lbl5:
+ ret i32 10
+
+}
+
+; CHECK: LJTI0_0:
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0
+; CHECK-NEXT: .long LBB{{[0-9_]+}}-Ltmp0