[llvm] [RISC-V][HWASAN] Allow disabling short granules (PR #103729)

Samuel Holland via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 11:14:51 PDT 2024


https://github.com/SiFiveHolland updated https://github.com/llvm/llvm-project/pull/103729

From f3a50190e0fdb712a0aa7dd33eac7bc65350237b Mon Sep 17 00:00:00 2001
From: Samuel Holland <samuel.holland at sifive.com>
Date: Fri, 2 Aug 2024 16:30:09 -0700
Subject: [PATCH] [RISC-V][HWASAN] Allow disabling short granules

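The Linux kernel's HWASAN does not use short granules. Currently, passing
hwasan-use-short-granules=0 causes an internal compiler error (ICE), because
there is no instruction-selection pattern for int_hwasan_check_memaccess.
Add a HWASAN_CHECK_MEMACCESS pseudo for that intrinsic, and emit its
__hwasan_check_x<reg>_<accessinfo> routines without the short-granule path.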
---
 llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp     | 110 ++++++++++--------
 llvm/lib/Target/RISCV/RISCVInstrInfo.td       |   7 ++
 .../CodeGen/RISCV/hwasan-check-memaccess.ll   | 102 +++++++++++++---
 3 files changed, 155 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
index 5ad09ae7290fc5..da0c733f3e6f72 100644
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -94,7 +94,7 @@ class RISCVAsmPrinter : public AsmPrinter {
 
   bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
 
-  typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
+  typedef std::tuple<unsigned, bool, uint32_t> HwasanMemaccessTuple;
   std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
   void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
   void LowerKCFI_CHECK(const MachineInstr &MI);
@@ -305,6 +305,7 @@ void RISCVAsmPrinter::emitInstruction(const MachineInstr *MI) {
   }
 
   switch (MI->getOpcode()) {
+  case RISCV::HWASAN_CHECK_MEMACCESS:
   case RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES:
     LowerHWASAN_CHECK_MEMACCESS(*MI);
     return;
@@ -522,16 +523,19 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVAsmPrinter() {
 
 void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
   Register Reg = MI.getOperand(0).getReg();
+  bool IsShort = MI.getOpcode() == RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES;
   uint32_t AccessInfo = MI.getOperand(1).getImm();
   MCSymbol *&Sym =
-      HwasanMemaccessSymbols[HwasanMemaccessTuple(Reg, AccessInfo)];
+      HwasanMemaccessSymbols[HwasanMemaccessTuple(Reg, IsShort, AccessInfo)];
   if (!Sym) {
     // FIXME: Make this work on non-ELF.
     if (!TM.getTargetTriple().isOSBinFormatELF())
       report_fatal_error("llvm.hwasan.check.memaccess only supported on ELF");
 
-    std::string SymName = "__hwasan_check_x" + utostr(Reg - RISCV::X0) + "_" +
-                          utostr(AccessInfo) + "_short";
+    std::string SymName =
+        "__hwasan_check_x" + utostr(Reg - RISCV::X0) + "_" + utostr(AccessInfo);
+    if (IsShort)
+      SymName += "_short";
     Sym = OutContext.getOrCreateSymbol(SymName);
   }
   auto Res = MCSymbolRefExpr::create(Sym, MCSymbolRefExpr::VK_None, OutContext);
@@ -651,7 +655,8 @@ void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
 
   for (auto &P : HwasanMemaccessSymbols) {
     unsigned Reg = std::get<0>(P.first);
-    uint32_t AccessInfo = std::get<1>(P.first);
+    bool IsShort = std::get<1>(P.first);
+    uint32_t AccessInfo = std::get<2>(P.first);
     MCSymbol *Sym = P.second;
 
     unsigned Size =
@@ -712,57 +717,62 @@ void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
                    MCSTI);
     OutStreamer->emitLabel(HandleMismatchOrPartialSym);
 
-    EmitToStreamer(*OutStreamer,
-                   MCInstBuilder(RISCV::ADDI)
-                       .addReg(RISCV::X28)
-                       .addReg(RISCV::X0)
-                       .addImm(16),
-                   MCSTI);
-    MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
-    EmitToStreamer(
-        *OutStreamer,
-        MCInstBuilder(RISCV::BGEU)
-            .addReg(RISCV::X6)
-            .addReg(RISCV::X28)
-            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
-        MCSTI);
-
-    EmitToStreamer(
-        *OutStreamer,
-        MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
-        MCSTI);
-
-    if (Size != 1)
+    if (IsShort) {
       EmitToStreamer(*OutStreamer,
                      MCInstBuilder(RISCV::ADDI)
                          .addReg(RISCV::X28)
-                         .addReg(RISCV::X28)
-                         .addImm(Size - 1),
+                         .addReg(RISCV::X0)
+                         .addImm(16),
                      MCSTI);
-    EmitToStreamer(
-        *OutStreamer,
-        MCInstBuilder(RISCV::BGE)
-            .addReg(RISCV::X28)
-            .addReg(RISCV::X6)
-            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
-        MCSTI);
+      MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::BGEU)
+              .addReg(RISCV::X6)
+              .addReg(RISCV::X28)
+              .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+          MCSTI);
 
-    EmitToStreamer(
-        *OutStreamer,
-        MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
-        MCSTI);
-    EmitToStreamer(
-        *OutStreamer,
-        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
-        MCSTI);
-    EmitToStreamer(*OutStreamer,
-                   MCInstBuilder(RISCV::BEQ)
-                       .addReg(RISCV::X6)
-                       .addReg(RISCV::X7)
-                       .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
-                   MCSTI);
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
+          MCSTI);
 
-    OutStreamer->emitLabel(HandleMismatchSym);
+      if (Size != 1)
+        EmitToStreamer(*OutStreamer,
+                       MCInstBuilder(RISCV::ADDI)
+                           .addReg(RISCV::X28)
+                           .addReg(RISCV::X28)
+                           .addImm(Size - 1),
+                       MCSTI);
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::BGE)
+              .addReg(RISCV::X28)
+              .addReg(RISCV::X6)
+              .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+          MCSTI);
+
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
+          MCSTI);
+      EmitToStreamer(*OutStreamer,
+                     MCInstBuilder(RISCV::LBU)
+                         .addReg(RISCV::X6)
+                         .addReg(RISCV::X6)
+                         .addImm(0),
+                     MCSTI);
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::BEQ)
+              .addReg(RISCV::X6)
+              .addReg(RISCV::X7)
+              .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
+          MCSTI);
+
+      OutStreamer->emitLabel(HandleMismatchSym);
+    }
 
     // | Previous stack frames...        |
     // +=================================+ <-- [SP + 256]
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 5d329dceac6519..b049d3422b2bc0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -1979,6 +1979,13 @@ def : Pat<(trap), (UNIMP)>;
 // debugger if possible.
 def : Pat<(debugtrap), (EBREAK)>;
 
+let Predicates = [IsRV64], Uses = [X5],
+    Defs = [X1, X6, X7, X28, X29, X30, X31], Size = 8 in
+def HWASAN_CHECK_MEMACCESS
+  : Pseudo<(outs), (ins GPRJALR:$ptr, i32imm:$accessinfo),
+           [(int_hwasan_check_memaccess (i64 X5), GPRJALR:$ptr,
+                                        (i32 timm:$accessinfo))]>;
+
 let Predicates = [IsRV64], Uses = [X5],
     Defs = [X1, X6, X7, X28, X29, X30, X31], Size = 8 in
 def HWASAN_CHECK_MEMACCESS_SHORTGRANULES
diff --git a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
index dfd526c8964137..533b307709a2e2 100644
--- a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
+++ b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
@@ -4,6 +4,34 @@
 ; RUN: llc -mtriple=riscv64 -mattr=+c --riscv-no-aliases < %s \
 ; RUN:   | FileCheck %s --check-prefix=COMPRESS
 
+define ptr @f1(ptr %x0, ptr %x1) {
+; CHECK-LABEL: f1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset ra, -8
+; CHECK-NEXT:    mv t0, a1
+; CHECK-NEXT:    call __hwasan_check_x10_1
+; CHECK-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+;
+; COMPRESS-LABEL: f1:
+; COMPRESS:       # %bb.0:
+; COMPRESS-NEXT:    c.addi sp, -16
+; COMPRESS-NEXT:    .cfi_def_cfa_offset 16
+; COMPRESS-NEXT:    c.sdsp ra, 8(sp) # 8-byte Folded Spill
+; COMPRESS-NEXT:    .cfi_offset ra, -8
+; COMPRESS-NEXT:    c.mv t0, a1
+; COMPRESS-NEXT:    call __hwasan_check_x10_1
+; COMPRESS-NEXT:    c.ldsp ra, 8(sp) # 8-byte Folded Reload
+; COMPRESS-NEXT:    c.addi sp, 16
+; COMPRESS-NEXT:    c.jr ra
+  call void @llvm.hwasan.check.memaccess(ptr %x1, ptr %x0, i32 1)
+  ret ptr %x0
+}
+
 define ptr @f2(ptr %x0, ptr %x1) {
 ; CHECK-LABEL: f2:
 ; CHECK:       # %bb.0:
@@ -32,6 +60,52 @@ define ptr @f2(ptr %x0, ptr %x1) {
   ret ptr %x0
 }
 
+declare void @llvm.hwasan.check.memaccess(ptr, ptr, i32)
+
+; CHECK: .section        .text.hot,"axG",@progbits,__hwasan_check_x10_1,comdat
+; CHECK-NEXT: .type   __hwasan_check_x10_1,@function
+; CHECK-NEXT: .weak   __hwasan_check_x10_1
+; CHECK-NEXT: .hidden __hwasan_check_x10_1
+; CHECK-NEXT: __hwasan_check_x10_1:
+; CHECK-NEXT: slli    t1, a0, 8
+; CHECK-NEXT: srli    t1, t1, 12
+; CHECK-NEXT: add     t1, t0, t1
+; CHECK-NEXT: lbu     t1, 0(t1)
+; CHECK-NEXT: srli    t2, a0, 56
+; CHECK-NEXT: bne     t2, t1, .Ltmp0
+; CHECK-NEXT: .Ltmp1:
+; CHECK-NEXT: ret
+; CHECK-NEXT: .Ltmp0:
+; CHECK-NEXT: addi    sp, sp, -256
+; CHECK-NEXT: sd      a0, 80(sp)
+; CHECK-NEXT: sd      a1, 88(sp)
+; CHECK-NEXT: sd      s0, 64(sp)
+; CHECK-NEXT: sd      ra, 8(sp)
+; CHECK-NEXT: li      a1, 1
+; CHECK-NEXT: call    __hwasan_tag_mismatch_v2
+
+; COMPRESS: .section        .text.hot,"axG",@progbits,__hwasan_check_x10_1,comdat
+; COMPRESS-NEXT: .type   __hwasan_check_x10_1,@function
+; COMPRESS-NEXT: .weak   __hwasan_check_x10_1
+; COMPRESS-NEXT: .hidden __hwasan_check_x10_1
+; COMPRESS-NEXT: __hwasan_check_x10_1:
+; COMPRESS-NEXT: slli    t1, a0, 8
+; COMPRESS-NEXT: srli    t1, t1, 12
+; COMPRESS-NEXT: c.add   t1, t0
+; COMPRESS-NEXT: lbu     t1, 0(t1)
+; COMPRESS-NEXT: srli    t2, a0, 56
+; COMPRESS-NEXT: bne     t2, t1, .Ltmp0
+; COMPRESS-NEXT: .Ltmp1:
+; COMPRESS-NEXT: c.jr    ra
+; COMPRESS-NEXT: .Ltmp0:
+; COMPRESS-NEXT: c.addi16sp sp, -256
+; COMPRESS-NEXT: c.sdsp a0, 80(sp)
+; COMPRESS-NEXT: c.sdsp a1, 88(sp)
+; COMPRESS-NEXT: c.sdsp s0, 64(sp)
+; COMPRESS-NEXT: c.sdsp ra, 8(sp)
+; COMPRESS-NEXT: c.li    a1, 1
+; COMPRESS-NEXT: call    __hwasan_tag_mismatch_v2
+
 declare void @llvm.hwasan.check.memaccess.shortgranules(ptr, ptr, i32)
 
 ; CHECK: .section        .text.hot,"axG",@progbits,__hwasan_check_x10_2_short,comdat
@@ -44,19 +118,19 @@ declare void @llvm.hwasan.check.memaccess.shortgranules(ptr, ptr, i32)
 ; CHECK-NEXT: add     t1, t0, t1
 ; CHECK-NEXT: lbu     t1, 0(t1)
 ; CHECK-NEXT: srli    t2, a0, 56
-; CHECK-NEXT: bne     t2, t1, .Ltmp0
-; CHECK-NEXT: .Ltmp1:
+; CHECK-NEXT: bne     t2, t1, .Ltmp2
+; CHECK-NEXT: .Ltmp3:
 ; CHECK-NEXT: ret
-; CHECK-NEXT: .Ltmp0:
+; CHECK-NEXT: .Ltmp2:
 ; CHECK-NEXT: li      t3, 16
-; CHECK-NEXT: bgeu    t1, t3, .Ltmp2
+; CHECK-NEXT: bgeu    t1, t3, .Ltmp4
 ; CHECK-NEXT: andi    t3, a0, 15
 ; CHECK-NEXT: addi    t3, t3, 3
-; CHECK-NEXT: bge     t3, t1, .Ltmp2
+; CHECK-NEXT: bge     t3, t1, .Ltmp4
 ; CHECK-NEXT: ori     t1, a0, 15
 ; CHECK-NEXT: lbu     t1, 0(t1)
-; CHECK-NEXT: beq     t1, t2, .Ltmp1
-; CHECK-NEXT: .Ltmp2:
+; CHECK-NEXT: beq     t1, t2, .Ltmp3
+; CHECK-NEXT: .Ltmp4:
 ; CHECK-NEXT: addi    sp, sp, -256
 ; CHECK-NEXT: sd      a0, 80(sp)
 ; CHECK-NEXT: sd      a1, 88(sp)
@@ -75,19 +149,19 @@ declare void @llvm.hwasan.check.memaccess.shortgranules(ptr, ptr, i32)
 ; COMPRESS-NEXT: c.add   t1, t0
 ; COMPRESS-NEXT: lbu     t1, 0(t1)
 ; COMPRESS-NEXT: srli    t2, a0, 56
-; COMPRESS-NEXT: bne     t2, t1, .Ltmp0
-; COMPRESS-NEXT: .Ltmp1:
+; COMPRESS-NEXT: bne     t2, t1, .Ltmp2
+; COMPRESS-NEXT: .Ltmp3:
 ; COMPRESS-NEXT: c.jr    ra
-; COMPRESS-NEXT: .Ltmp0:
+; COMPRESS-NEXT: .Ltmp2:
 ; COMPRESS-NEXT: c.li    t3, 16
-; COMPRESS-NEXT: bgeu    t1, t3, .Ltmp2
+; COMPRESS-NEXT: bgeu    t1, t3, .Ltmp4
 ; COMPRESS-NEXT: andi    t3, a0, 15
 ; COMPRESS-NEXT: c.addi  t3, 3
-; COMPRESS-NEXT: bge     t3, t1, .Ltmp2
+; COMPRESS-NEXT: bge     t3, t1, .Ltmp4
 ; COMPRESS-NEXT: ori     t1, a0, 15
 ; COMPRESS-NEXT: lbu     t1, 0(t1)
-; COMPRESS-NEXT: beq     t1, t2, .Ltmp1
-; COMPRESS-NEXT: .Ltmp2:
+; COMPRESS-NEXT: beq     t1, t2, .Ltmp3
+; COMPRESS-NEXT: .Ltmp4:
 ; COMPRESS-NEXT: c.addi16sp sp, -256
 ; COMPRESS-NEXT: c.sdsp a0, 80(sp)
 ; COMPRESS-NEXT: c.sdsp a1, 88(sp)



More information about the llvm-commits mailing list