[llvm] [AArch64][PAC] Lower jump-tables using hardened pseudo. (PR #97666)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 21:42:12 PDT 2024


================
@@ -1310,6 +1312,138 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
                                   .addImm(Size == 4 ? 0 : 2));
 }
 
+void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
+  unsigned InstsEmitted = 0;
+
+  const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
+  assert(MJTI && "Can't lower jump-table dispatch without JTI");
+
+  const std::vector<MachineJumpTableEntry> &JTs = MJTI->getJumpTables();
+  assert(!JTs.empty() && "Invalid JT index for jump-table dispatch");
+
+  // Emit:
+  //     mov x17, #<size of table>     ; depending on table size, with MOVKs
+  //     cmp x16, x17                  ; or #imm if table size fits in 12-bit
+  //     csel x16, x16, xzr, ls        ; check for index overflow
+  //
+  //     adrp x17, Ltable@PAGE         ; materialize table address
+  //     add x17, Ltable@PAGEOFF
+  //     ldrsw x16, [x17, x16, lsl #2] ; load table entry
+  //
+  //   Lanchor:
+  //     adr x17, Lanchor              ; compute target address
+  //     add x16, x17, x16
+  //     br x16                        ; branch to target
+
+  MachineOperand JTOp = MI.getOperand(0);
+
+  unsigned JTI = JTOp.getIndex();
+  assert(!AArch64FI->getJumpTableEntryPCRelSymbol(JTI) &&
+         "unsupported compressed jump table");
+
+  const uint64_t NumTableEntries = JTs[JTI].MBBs.size();
+
+  // cmp only supports a 12-bit immediate.  If we need more, materialize the
+  // immediate, using x17 as a scratch register.
+  uint64_t MaxTableEntry = NumTableEntries - 1;
+  if (isUInt<12>(MaxTableEntry)) {
+    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXri)
+                                     .addReg(AArch64::XZR)
+                                     .addReg(AArch64::X16)
+                                     .addImm(MaxTableEntry)
+                                     .addImm(0));
+    ++InstsEmitted;
+  } else {
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(AArch64::MOVZXi)
+                       .addReg(AArch64::X17)
+                       .addImm(static_cast<uint16_t>(MaxTableEntry))
+                       .addImm(0));
+    ++InstsEmitted;
+    // It's sad that we have to manually materialize instructions, but we can't
+    // trivially reuse the main pseudo expansion logic.
+    // A MOVK sequence is easy enough to generate and handles the general case.
+    for (int Offset = 16; Offset < 64; Offset += 16) {
+      if ((MaxTableEntry >> Offset) == 0)
+        break;
+      EmitToStreamer(*OutStreamer,
+                     MCInstBuilder(AArch64::MOVKXi)
+                         .addReg(AArch64::X17)
+                         .addReg(AArch64::X17)
+                         .addImm(static_cast<uint16_t>(MaxTableEntry >> Offset))
+                         .addImm(Offset));
+      ++InstsEmitted;
+    }
+    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
+                                     .addReg(AArch64::XZR)
+                                     .addReg(AArch64::X16)
+                                     .addReg(AArch64::X17)
+                                     .addImm(0));
+    ++InstsEmitted;
+  }
+
+  // This picks entry #0 on failure.
+  // We might want to trap instead.
+  EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::CSELXr)
+                                   .addReg(AArch64::X16)
+                                   .addReg(AArch64::X16)
+                                   .addReg(AArch64::XZR)
+                                   .addImm(AArch64CC::LS));
+  ++InstsEmitted;
+
+  // Prepare the @PAGE/@PAGEOFF low/high operands.
+  MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
+  MCOperand JTMCHi, JTMCLo;
+
+  JTMOHi.setTargetFlags(AArch64II::MO_PAGE);
+  JTMOLo.setTargetFlags(AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
+
+  MCInstLowering.lowerOperand(JTMOHi, JTMCHi);
+  MCInstLowering.lowerOperand(JTMOLo, JTMCLo);
+
+  EmitToStreamer(
+      *OutStreamer,
+      MCInstBuilder(AArch64::ADRP).addReg(AArch64::X17).addOperand(JTMCHi));
+  ++InstsEmitted;
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
+                                   .addReg(AArch64::X17)
+                                   .addReg(AArch64::X17)
+                                   .addOperand(JTMCLo)
+                                   .addImm(0));
+  ++InstsEmitted;
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
+                                   .addReg(AArch64::X16)
+                                   .addReg(AArch64::X17)
+                                   .addReg(AArch64::X16)
+                                   .addImm(0)
+                                   .addImm(1));
+  ++InstsEmitted;
+
+  MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
+  auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
+  AArch64FI->setJumpTableEntryInfo(JTI, 4, AdrLabel);
+
+  OutStreamer->emitLabel(AdrLabel);
+  EmitToStreamer(
+      *OutStreamer,
+      MCInstBuilder(AArch64::ADR).addReg(AArch64::X17).addExpr(AdrLabelE));
+  ++InstsEmitted;
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
+                                   .addReg(AArch64::X16)
+                                   .addReg(AArch64::X17)
+                                   .addReg(AArch64::X16)
+                                   .addImm(0));
+  ++InstsEmitted;
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR).addReg(AArch64::X16));
+  ++InstsEmitted;
+
+  assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
----------------
jroelofs wrote:

(void)InstsEmitted;

https://github.com/llvm/llvm-project/pull/97666


More information about the llvm-commits mailing list