[llvm] [RISCV] Use RISCVAsmPrinter::EmitToStreamer for EmitHwasanMemaccessSymbols. (PR #111792)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 22:12:35 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/111792

Add an MCSubtargetInfo& parameter so we can control the subtarget for this use. The old signature is kept as a wrapper that passes *STI to maintain compatibility.

By using EmitToStreamer we are able to compress the instructions when possible.

cc: @SiFiveHolland 

>From 9ba4d68928748138c9a899c980b0f36fa0660a7f Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 9 Oct 2024 21:55:57 -0700
Subject: [PATCH] [RISCV] Use RISCVAsmPrinter::EmitToStreamer for
 EmitHwasanMemaccessSymbols.

Add an MCSubtargetInfo& parameter so we can control the subtarget for
this use. The old signature is kept as a wrapper that passes *STI to
maintain compatibility.

By using EmitToStreamer we are able to compress the instructions
when possible.
---
 llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp     | 184 ++++++++++--------
 .../CodeGen/RISCV/hwasan-check-memaccess.ll   |  45 +++++
 2 files changed, 148 insertions(+), 81 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
index 384a7cf59f0632..5ad09ae7290fc5 100644
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -86,7 +86,11 @@ class RISCVAsmPrinter : public AsmPrinter {
                              const char *ExtraCode, raw_ostream &OS) override;
 
   // Returns whether Inst is compressed.
-  bool EmitToStreamer(MCStreamer &S, const MCInst &Inst);
+  bool EmitToStreamer(MCStreamer &S, const MCInst &Inst,
+                      const MCSubtargetInfo &SubtargetInfo);
+  bool EmitToStreamer(MCStreamer &S, const MCInst &Inst) {
+    return EmitToStreamer(S, Inst, *STI);
+  }
 
   bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
 
@@ -242,12 +246,13 @@ void RISCVAsmPrinter::LowerSTATEPOINT(MCStreamer &OutStreamer, StackMaps &SM,
   SM.recordStatepoint(*MILabel, MI);
 }
 
-bool RISCVAsmPrinter::EmitToStreamer(MCStreamer &S, const MCInst &Inst) {
+bool RISCVAsmPrinter::EmitToStreamer(MCStreamer &S, const MCInst &Inst,
+                                     const MCSubtargetInfo &SubtargetInfo) {
   MCInst CInst;
-  bool Res = RISCVRVC::compress(CInst, Inst, *STI);
+  bool Res = RISCVRVC::compress(CInst, Inst, SubtargetInfo);
   if (Res)
     ++RISCVNumInstrsCompressed;
-  S.emitInstruction(Res ? CInst : Inst, *STI);
+  S.emitInstruction(Res ? CInst : Inst, SubtargetInfo);
   return Res;
 }
 
@@ -662,87 +667,100 @@ void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
     OutStreamer->emitLabel(Sym);
 
     // Extract shadow offset from ptr
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::SLLI).addReg(RISCV::X6).addReg(Reg).addImm(8),
         MCSTI);
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SRLI)
-                                     .addReg(RISCV::X6)
-                                     .addReg(RISCV::X6)
-                                     .addImm(12),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::SRLI)
+                       .addReg(RISCV::X6)
+                       .addReg(RISCV::X6)
+                       .addImm(12),
+                   MCSTI);
     // load shadow tag in X6, X5 contains shadow base
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADD)
-                                     .addReg(RISCV::X6)
-                                     .addReg(RISCV::X5)
-                                     .addReg(RISCV::X6),
-                                 MCSTI);
-    OutStreamer->emitInstruction(
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::ADD)
+                       .addReg(RISCV::X6)
+                       .addReg(RISCV::X5)
+                       .addReg(RISCV::X6),
+                   MCSTI);
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
         MCSTI);
     // Extract tag from pointer and compare it with loaded tag from shadow
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::SRLI).addReg(RISCV::X7).addReg(Reg).addImm(56),
         MCSTI);
     MCSymbol *HandleMismatchOrPartialSym = OutContext.createTempSymbol();
     // X7 contains tag from the pointer, while X6 contains tag from memory
-    OutStreamer->emitInstruction(
-        MCInstBuilder(RISCV::BNE)
-            .addReg(RISCV::X7)
-            .addReg(RISCV::X6)
-            .addExpr(MCSymbolRefExpr::create(HandleMismatchOrPartialSym,
-                                             OutContext)),
-        MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::BNE)
+                       .addReg(RISCV::X7)
+                       .addReg(RISCV::X6)
+                       .addExpr(MCSymbolRefExpr::create(
+                           HandleMismatchOrPartialSym, OutContext)),
+                   MCSTI);
     MCSymbol *ReturnSym = OutContext.createTempSymbol();
     OutStreamer->emitLabel(ReturnSym);
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
-                                     .addReg(RISCV::X0)
-                                     .addReg(RISCV::X1)
-                                     .addImm(0),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::JALR)
+                       .addReg(RISCV::X0)
+                       .addReg(RISCV::X1)
+                       .addImm(0),
+                   MCSTI);
     OutStreamer->emitLabel(HandleMismatchOrPartialSym);
 
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
-                                     .addReg(RISCV::X28)
-                                     .addReg(RISCV::X0)
-                                     .addImm(16),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::ADDI)
+                       .addReg(RISCV::X28)
+                       .addReg(RISCV::X0)
+                       .addImm(16),
+                   MCSTI);
     MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::BGEU)
             .addReg(RISCV::X6)
             .addReg(RISCV::X28)
             .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
         MCSTI);
 
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
         MCSTI);
 
     if (Size != 1)
-      OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
-                                       .addReg(RISCV::X28)
-                                       .addReg(RISCV::X28)
-                                       .addImm(Size - 1),
-                                   MCSTI);
-    OutStreamer->emitInstruction(
+      EmitToStreamer(*OutStreamer,
+                     MCInstBuilder(RISCV::ADDI)
+                         .addReg(RISCV::X28)
+                         .addReg(RISCV::X28)
+                         .addImm(Size - 1),
+                     MCSTI);
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::BGE)
             .addReg(RISCV::X28)
             .addReg(RISCV::X6)
             .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
         MCSTI);
 
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
         MCSTI);
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
         MCSTI);
-    OutStreamer->emitInstruction(
-        MCInstBuilder(RISCV::BEQ)
-            .addReg(RISCV::X6)
-            .addReg(RISCV::X7)
-            .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
-        MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::BEQ)
+                       .addReg(RISCV::X6)
+                       .addReg(RISCV::X7)
+                       .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
+                   MCSTI);
 
     OutStreamer->emitLabel(HandleMismatchSym);
 
@@ -781,50 +799,54 @@ void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
     // +---------------------------------+ <-- [x2 / SP]
 
     // Adjust sp
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
-                                     .addReg(RISCV::X2)
-                                     .addReg(RISCV::X2)
-                                     .addImm(-256),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::ADDI)
+                       .addReg(RISCV::X2)
+                       .addReg(RISCV::X2)
+                       .addImm(-256),
+                   MCSTI);
 
     // store x10(arg0) by new sp
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
-                                     .addReg(RISCV::X10)
-                                     .addReg(RISCV::X2)
-                                     .addImm(8 * 10),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::SD)
+                       .addReg(RISCV::X10)
+                       .addReg(RISCV::X2)
+                       .addImm(8 * 10),
+                   MCSTI);
     // store x11(arg1) by new sp
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
-                                     .addReg(RISCV::X11)
-                                     .addReg(RISCV::X2)
-                                     .addImm(8 * 11),
-                                 MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::SD)
+                       .addReg(RISCV::X11)
+                       .addReg(RISCV::X2)
+                       .addImm(8 * 11),
+                   MCSTI);
 
     // store x8(fp) by new sp
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::SD).addReg(RISCV::X8).addReg(RISCV::X2).addImm(8 *
                                                                             8),
         MCSTI);
     // store x1(ra) by new sp
-    OutStreamer->emitInstruction(
+    EmitToStreamer(
+        *OutStreamer,
         MCInstBuilder(RISCV::SD).addReg(RISCV::X1).addReg(RISCV::X2).addImm(1 *
                                                                             8),
         MCSTI);
     if (Reg != RISCV::X10)
-      OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
-                                       .addReg(RISCV::X10)
-                                       .addReg(Reg)
-                                       .addImm(0),
-                                   MCSTI);
-    OutStreamer->emitInstruction(
-        MCInstBuilder(RISCV::ADDI)
-            .addReg(RISCV::X11)
-            .addReg(RISCV::X0)
-            .addImm(AccessInfo & HWASanAccessInfo::RuntimeMask),
-        MCSTI);
-
-    OutStreamer->emitInstruction(MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr),
-                                 MCSTI);
+      EmitToStreamer(
+          *OutStreamer,
+          MCInstBuilder(RISCV::ADDI).addReg(RISCV::X10).addReg(Reg).addImm(0),
+          MCSTI);
+    EmitToStreamer(*OutStreamer,
+                   MCInstBuilder(RISCV::ADDI)
+                       .addReg(RISCV::X11)
+                       .addReg(RISCV::X0)
+                       .addImm(AccessInfo & HWASanAccessInfo::RuntimeMask),
+                   MCSTI);
+
+    EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr),
+                   MCSTI);
   }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
index 12c95206d21bed..dfd526c8964137 100644
--- a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
+++ b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv64 < %s | FileCheck %s
 ; RUN: llc -mtriple=riscv64 --relocation-model=pic < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+c --riscv-no-aliases < %s \
+; RUN:   | FileCheck %s --check-prefix=COMPRESS
 
 define ptr @f2(ptr %x0, ptr %x1) {
 ; CHECK-LABEL: f2:
@@ -14,6 +16,18 @@ define ptr @f2(ptr %x0, ptr %x1) {
 ; CHECK-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
 ; CHECK-NEXT:    addi sp, sp, 16
 ; CHECK-NEXT:    ret
+;
+; COMPRESS-LABEL: f2:
+; COMPRESS:       # %bb.0:
+; COMPRESS-NEXT:    c.addi sp, -16
+; COMPRESS-NEXT:    .cfi_def_cfa_offset 16
+; COMPRESS-NEXT:    c.sdsp ra, 8(sp) # 8-byte Folded Spill
+; COMPRESS-NEXT:    .cfi_offset ra, -8
+; COMPRESS-NEXT:    c.mv t0, a1
+; COMPRESS-NEXT:    call __hwasan_check_x10_2_short
+; COMPRESS-NEXT:    c.ldsp ra, 8(sp) # 8-byte Folded Reload
+; COMPRESS-NEXT:    c.addi sp, 16
+; COMPRESS-NEXT:    c.jr ra
   call void @llvm.hwasan.check.memaccess.shortgranules(ptr %x1, ptr %x0, i32 2)
   ret ptr %x0
 }
@@ -50,3 +64,34 @@ declare void @llvm.hwasan.check.memaccess.shortgranules(ptr, ptr, i32)
 ; CHECK-NEXT: sd      ra, 8(sp)
 ; CHECK-NEXT: li      a1, 2
 ; CHECK-NEXT: call    __hwasan_tag_mismatch_v2
+
+; COMPRESS: .section        .text.hot,"axG",@progbits,__hwasan_check_x10_2_short,comdat
+; COMPRESS-NEXT: .type   __hwasan_check_x10_2_short,@function
+; COMPRESS-NEXT: .weak   __hwasan_check_x10_2_short
+; COMPRESS-NEXT: .hidden __hwasan_check_x10_2_short
+; COMPRESS-NEXT: __hwasan_check_x10_2_short:
+; COMPRESS-NEXT: slli    t1, a0, 8
+; COMPRESS-NEXT: srli    t1, t1, 12
+; COMPRESS-NEXT: c.add   t1, t0
+; COMPRESS-NEXT: lbu     t1, 0(t1)
+; COMPRESS-NEXT: srli    t2, a0, 56
+; COMPRESS-NEXT: bne     t2, t1, .Ltmp0
+; COMPRESS-NEXT: .Ltmp1:
+; COMPRESS-NEXT: c.jr    ra
+; COMPRESS-NEXT: .Ltmp0:
+; COMPRESS-NEXT: c.li    t3, 16
+; COMPRESS-NEXT: bgeu    t1, t3, .Ltmp2
+; COMPRESS-NEXT: andi    t3, a0, 15
+; COMPRESS-NEXT: c.addi  t3, 3
+; COMPRESS-NEXT: bge     t3, t1, .Ltmp2
+; COMPRESS-NEXT: ori     t1, a0, 15
+; COMPRESS-NEXT: lbu     t1, 0(t1)
+; COMPRESS-NEXT: beq     t1, t2, .Ltmp1
+; COMPRESS-NEXT: .Ltmp2:
+; COMPRESS-NEXT: c.addi16sp sp, -256
+; COMPRESS-NEXT: c.sdsp a0, 80(sp)
+; COMPRESS-NEXT: c.sdsp a1, 88(sp)
+; COMPRESS-NEXT: c.sdsp s0, 64(sp)
+; COMPRESS-NEXT: c.sdsp ra, 8(sp)
+; COMPRESS-NEXT: c.li    a1, 2
+; COMPRESS-NEXT: call    __hwasan_tag_mismatch_v2



More information about the llvm-commits mailing list