[llvm] e348534 - [RISC-V][HWASAN] Add support for lowering HWASAN intrinsic for RISC-V

Alexey Baturo via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 28 11:23:25 PDT 2022


Author: Alexey Baturo
Date: 2022-08-28T21:22:13+03:00
New Revision: e3485345d30c4eded8b983810b2e07dbc300e2b3

URL: https://github.com/llvm/llvm-project/commit/e3485345d30c4eded8b983810b2e07dbc300e2b3
DIFF: https://github.com/llvm/llvm-project/commit/e3485345d30c4eded8b983810b2e07dbc300e2b3.diff

LOG: [RISC-V][HWASAN] Add support for lowering HWASAN intrinsic for RISC-V

Reviewed By: vitalybuka

Differential Revision: https://reviews.llvm.org/D131343

Added: 
    llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
index 3f0584821fe3e..60b908ba2455d 100644
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -18,17 +18,24 @@
 #include "RISCVTargetMachine.h"
 #include "TargetInfo/RISCVTargetInfo.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/BinaryFormat/ELF.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/MC/MCAsmInfo.h"
+#include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCInst.h"
+#include "llvm/MC/MCInstBuilder.h"
+#include "llvm/MC/MCObjectFileInfo.h"
+#include "llvm/MC/MCSectionELF.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "asm-printer"
@@ -61,6 +68,11 @@ class RISCVAsmPrinter : public AsmPrinter {
   bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
                                    const MachineInstr *MI);
 
+  typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
+  std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
+  void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
+  void EmitHwasanMemaccessSymbols(Module &M);
+
   // Wrapper needed for tblgenned pseudo lowering.
   bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp) const {
     return lowerRISCVMachineOperandToMCOperand(MO, MCOp, *this);
@@ -97,6 +109,12 @@ void RISCVAsmPrinter::emitInstruction(const MachineInstr *MI) {
     return;
 
   MCInst TmpInst;
+
+  if (MI->getOpcode() == RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES) {
+    LowerHWASAN_CHECK_MEMACCESS(*MI);
+    return;
+  }
+
   if (!lowerRISCVMachineInstrToMCInst(MI, TmpInst, *this))
     EmitToStreamer(*OutStreamer, TmpInst);
 }
@@ -198,6 +216,7 @@ void RISCVAsmPrinter::emitEndOfAsmFile(Module &M) {
 
   if (TM.getTargetTriple().isOSBinFormatELF())
     RTS.finishAttributeSection();
+  EmitHwasanMemaccessSymbols(M);
 }
 
 void RISCVAsmPrinter::emitAttributes() {
@@ -211,3 +230,253 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVAsmPrinter() {
   RegisterAsmPrinter<RISCVAsmPrinter> X(getTheRISCV32Target());
   RegisterAsmPrinter<RISCVAsmPrinter> Y(getTheRISCV64Target());
 }
+
+void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
+  Register Reg = MI.getOperand(0).getReg(); // Register holding the pointer.
+  uint32_t AccessInfo = MI.getOperand(1).getImm(); // Packed access info.
+  MCSymbol *&Sym =
+      HwasanMemaccessSymbols[HwasanMemaccessTuple(Reg, AccessInfo)];
+  if (!Sym) {
+    // FIXME: Make this work on non-ELF.
+    if (!TM.getTargetTriple().isOSBinFormatELF())
+      report_fatal_error("llvm.hwasan.check.memaccess only supported on ELF");
+    // The name encodes the register number and the raw AccessInfo value.
+    std::string SymName = "__hwasan_check_x" + utostr(Reg - RISCV::X0) + "_" +
+                          utostr(AccessInfo) + "_short";
+    Sym = OutContext.getOrCreateSymbol(SymName);
+  }
+  auto Res = MCSymbolRefExpr::create(Sym, MCSymbolRefExpr::VK_None, OutContext);
+  auto Expr = RISCVMCExpr::create(Res, RISCVMCExpr::VK_RISCV_CALL, OutContext);
+  // Emit a call to the routine; its body is emitted at end of file.
+  EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr));
+}
+
+void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
+  if (HwasanMemaccessSymbols.empty())
+    return;
+  // Emit every routine recorded by LowerHWASAN_CHECK_MEMACCESS; ELF only.
+  const Triple &TT = TM.getTargetTriple();
+  assert(TT.isOSBinFormatELF());
+
+  MCSymbol *HwasanTagMismatchV2Sym =
+      OutContext.getOrCreateSymbol("__hwasan_tag_mismatch_v2");
+
+  const MCSymbolRefExpr *HwasanTagMismatchV2Ref =
+      MCSymbolRefExpr::create(HwasanTagMismatchV2Sym, OutContext);
+  // Emit one check routine per (pointer register, access info) pair.
+  for (auto &P : HwasanMemaccessSymbols) {
+    unsigned Reg = std::get<0>(P.first);
+    uint32_t AccessInfo = std::get<1>(P.first);
+    const MCSymbolRefExpr *HwasanTagMismatchRef = HwasanTagMismatchV2Ref;
+    MCSymbol *Sym = P.second;
+    // The access size in bytes is a power of two encoded in AccessInfo.
+    unsigned Size =
+        1 << ((AccessInfo >> HWASanAccessInfo::AccessSizeShift) & 0xf);
+    OutStreamer->switchSection(OutContext.getELFSection(
+        ".text.hot", ELF::SHT_PROGBITS,
+        ELF::SHF_EXECINSTR | ELF::SHF_ALLOC | ELF::SHF_GROUP, 0, Sym->getName(),
+        /*IsComdat=*/true));
+    // The routine is a weak, hidden function in its own comdat group.
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_ELF_TypeFunction);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Weak);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Hidden);
+    OutStreamer->emitLabel(Sym);
+
+    // Extract shadow offset from ptr: drop the top 8 bits, then divide by 16
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SLLI).addReg(RISCV::X6).addReg(Reg).addImm(8),
+        *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SRLI)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addImm(12),
+                                 *STI);
+    // load shadow tag in X6, X5 contains shadow base
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADD)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X5)
+                                     .addReg(RISCV::X6),
+                                 *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    // Extract tag from the pointer and compare it with loaded tag from shadow
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SRLI).addReg(RISCV::X7).addReg(Reg).addImm(56),
+        *STI);
+    MCSymbol *HandleMismatchOrPartialSym = OutContext.createTempSymbol();
+    // X7 contains tag from the pointer, while X6 contains tag from shadow memory
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BNE)
+            .addReg(RISCV::X7)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchOrPartialSym,
+                                             OutContext)),
+        *STI);
+    MCSymbol *ReturnSym = OutContext.createTempSymbol();
+    OutStreamer->emitLabel(ReturnSym);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X1)
+                                     .addImm(0),
+                                 *STI);
+    OutStreamer->emitLabel(HandleMismatchOrPartialSym);
+    // A loaded shadow byte >= 16 goes straight to the mismatch handler.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X28)
+                                     .addReg(RISCV::X0)
+                                     .addImm(16),
+                                 *STI);
+    MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGEU)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X28)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+    // X28 = (ptr & 0xF) + Size - 1; mismatch if X28 >= shadow byte (signed).
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
+        *STI);
+
+    if (Size != 1)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                       .addReg(RISCV::X28)
+                                       .addReg(RISCV::X28)
+                                       .addImm(Size - 1),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGE)
+            .addReg(RISCV::X28)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+    // Reload X6 from the byte at (ptr | 0xF); return if it equals the ptr tag.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BEQ)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X7)
+            .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
+        *STI);
+
+    OutStreamer->emitLabel(HandleMismatchSym);
+    // Frame layout below; only x1, x8, x10 and x11 are stored by this routine.
+    // | Previous stack frames...        |
+    // +=================================+ <-- [SP + 256]
+    // |              ...                |
+    // |                                 |
+    // | Stack frame space for x12 - x31.|
+    // |                                 |
+    // |              ...                |
+    // +---------------------------------+ <-- [SP + 96]
+    // | Saved x11(arg1), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 88]
+    // | Saved x10(arg0), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 80]
+    // |                                 |
+    // | Stack frame space for x9.       |
+    // +---------------------------------+ <-- [SP + 72]
+    // |                                 |
+    // | Saved x8(fp), as                |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 64]
+    // |              ...                |
+    // |                                 |
+    // | Stack frame space for x2 - x7.  |
+    // |                                 |
+    // |              ...                |
+    // +---------------------------------+ <-- [SP + 16]
+    // | Return address (x1) for caller  |
+    // | of __hwasan_check_*.            |
+    // +---------------------------------+ <-- [SP + 8]
+    // | Reserved place for x0, possibly |
+    // | junk, since we don't save it.   |
+    // +---------------------------------+ <-- [x2 / SP]
+
+    // Adjust sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X2)
+                                     .addReg(RISCV::X2)
+                                     .addImm(-256),
+                                 *STI);
+
+    // store x10(arg0) at [new sp + 80]
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X10)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 10),
+                                 *STI);
+    // store x11(arg1) at [new sp + 88]
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X11)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 11),
+                                 *STI);
+
+    // store x8(fp) at [new sp + 64]
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SD).addReg(RISCV::X8).addReg(RISCV::X2).addImm(8 *
+                                                                            8),
+        *STI);
+    // store x1(ra) at [new sp + 8]
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SD).addReg(RISCV::X1).addReg(RISCV::X2).addImm(1 *
+                                                                            8),
+        *STI);
+    if (Reg != RISCV::X10)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::OR)
+                                       .addReg(RISCV::X10)
+                                       .addReg(RISCV::X0)
+                                       .addReg(Reg),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ADDI)
+            .addReg(RISCV::X11)
+            .addReg(RISCV::X0)
+            .addImm(AccessInfo & HWASanAccessInfo::RuntimeMask),
+        *STI);
+
+    // Intentionally load the GOT entry and branch to it, rather than possibly
+    // late binding the function, which may clobber the registers before we have
+    // a chance to save them. Non-PIC code forms the address PC-relatively.
+    RISCVMCExpr::VariantKind VKHi;
+    unsigned SecondOpcode;
+    if (OutContext.getObjectFileInfo()->isPositionIndependent()) {
+      SecondOpcode = RISCV::LD;
+      VKHi = RISCVMCExpr::VK_RISCV_GOT_HI;
+    } else {
+      SecondOpcode = RISCV::ADDI;
+      VKHi = RISCVMCExpr::VK_RISCV_PCREL_HI;
+    }
+    auto ExprHi = RISCVMCExpr::create(HwasanTagMismatchRef, VKHi, OutContext);
+
+    MCSymbol *TmpLabel =
+        OutContext.createTempSymbol("pcrel_hi", /* AlwaysAddSuffix */ true);
+    OutStreamer->emitLabel(TmpLabel);
+    const MCExpr *ExprLo =
+        RISCVMCExpr::create(MCSymbolRefExpr::create(TmpLabel, OutContext),
+                            RISCVMCExpr::VK_RISCV_PCREL_LO, OutContext);
+
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X6).addExpr(ExprHi), *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(SecondOpcode)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addExpr(ExprLo),
+                                 *STI);
+    // Jump without linking (rd = x0) to the address now held in X6.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X6)
+                                     .addImm(0),
+                                 *STI);
+  }
+}

diff  --git a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
new file mode 100644
index 0000000000000..f4328115dc2c3
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 < %s | FileCheck --check-prefixes=CHECK,NOPIC %s
+; RUN: llc -mtriple=riscv64 --relocation-model=pic < %s | FileCheck --check-prefixes=CHECK,PIC %s
+
+define i8* @f2(i8* %x0, i8* %x1) {
+; CHECK-LABEL: f2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset ra, -8
+; CHECK-NEXT:    mv t0, a1
+; CHECK-NEXT:    call __hwasan_check_x10_2_short
+; CHECK-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+  call void @llvm.hwasan.check.memaccess.shortgranules(i8* %x1, i8* %x0, i32 2)
+  ret i8* %x0
+}
+
+declare void @llvm.hwasan.check.memaccess.shortgranules(i8*, i8*, i32)
+
+; CHECK: .section        .text.hot,"axG",@progbits,__hwasan_check_x10_2_short,comdat
+; CHECK-NEXT: .type   __hwasan_check_x10_2_short,@function
+; CHECK-NEXT: .weak   __hwasan_check_x10_2_short
+; CHECK-NEXT: .hidden __hwasan_check_x10_2_short
+; CHECK-NEXT: __hwasan_check_x10_2_short:
+; CHECK-NEXT: slli    t1, a0, 8
+; CHECK-NEXT: srli    t1, t1, 12
+; CHECK-NEXT: add     t1, t0, t1
+; CHECK-NEXT: lbu     t1, 0(t1)
+; CHECK-NEXT: srli    t2, a0, 56
+; CHECK-NEXT: bne     t2, t1, .Ltmp0
+; CHECK-NEXT: .Ltmp1:
+; CHECK-NEXT: ret
+; CHECK-NEXT: .Ltmp0:
+; CHECK-NEXT: li      t3, 16
+; CHECK-NEXT: bgeu    t1, t3, .Ltmp2
+; CHECK-NEXT: andi    t3, a0, 15
+; CHECK-NEXT: addi    t3, t3, 3
+; CHECK-NEXT: bge     t3, t1, .Ltmp2
+; CHECK-NEXT: ori     t1, a0, 15
+; CHECK-NEXT: lbu     t1, 0(t1)
+; CHECK-NEXT: beq     t1, t2, .Ltmp1
+; CHECK-NEXT: .Ltmp2:
+; CHECK-NEXT: addi    sp, sp, -256
+; CHECK-NEXT: sd      a0, 80(sp)
+; CHECK-NEXT: sd      a1, 88(sp)
+; CHECK-NEXT: sd      s0, 64(sp)
+; CHECK-NEXT: sd      ra, 8(sp)
+; CHECK-NEXT: li      a1, 2
+; CHECK-NEXT: .Lpcrel_hi0:
+; NOPIC-NEXT: auipc   t1, %pcrel_hi(__hwasan_tag_mismatch_v2)
+; NOPIC-NEXT: addi    t1, t1, %pcrel_lo(.Lpcrel_hi0)
+; PIC-NEXT: auipc   t1, %got_pcrel_hi(__hwasan_tag_mismatch_v2)
+; PIC-NEXT: ld      t1, %pcrel_lo(.Lpcrel_hi0)(t1)
+; CHECK-NEXT: jr      t1
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; NOPIC: {{.*}}
+; PIC: {{.*}}


        


More information about the llvm-commits mailing list