[llvm] [RISCV] Add Stackmap/Statepoint/Patchpoint support with targets (PR #77337)

Sacha Coppey via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 10:56:16 PST 2024


https://github.com/Zeavee updated https://github.com/llvm/llvm-project/pull/77337

>From 5783e417a09909834df54cbc13f879cca7e8b964 Mon Sep 17 00:00:00 2001
From: Sacha Coppey <sacha.coppey at oracle.com>
Date: Mon, 22 May 2023 21:26:37 +0200
Subject: [PATCH 1/2] [RISCV] Add Stackmap/Statepoint/Patchpoint support with
 targets

---
 .../Target/RISCV/AsmParser/RISCVAsmParser.cpp | 31 ++--------
 .../Target/RISCV/MCTargetDesc/RISCVMatInt.cpp | 41 +++++++++++++
 .../Target/RISCV/MCTargetDesc/RISCVMatInt.h   |  7 +++
 llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp     | 57 +++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 11 ++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp      |  9 ++-
 llvm/test/CodeGen/RISCV/rv64-patchpoint.ll    | 46 ++++++++++++++-
 7 files changed, 172 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index d616aaeddf4114..494a21b49e7e38 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -2999,34 +2999,11 @@ void RISCVAsmParser::emitToStreamer(MCStreamer &S, const MCInst &Inst) {
 
 void RISCVAsmParser::emitLoadImm(MCRegister DestReg, int64_t Value,
                                  MCStreamer &Out) {
-  RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Value, getSTI());
-
-  MCRegister SrcReg = RISCV::X0;
-  for (const RISCVMatInt::Inst &Inst : Seq) {
-    switch (Inst.getOpndKind()) {
-    case RISCVMatInt::Imm:
-      emitToStreamer(Out,
-                     MCInstBuilder(Inst.getOpcode()).addReg(DestReg).addImm(Inst.getImm()));
-      break;
-    case RISCVMatInt::RegX0:
-      emitToStreamer(
-          Out, MCInstBuilder(Inst.getOpcode()).addReg(DestReg).addReg(SrcReg).addReg(
-                   RISCV::X0));
-      break;
-    case RISCVMatInt::RegReg:
-      emitToStreamer(
-          Out, MCInstBuilder(Inst.getOpcode()).addReg(DestReg).addReg(SrcReg).addReg(
-                   SrcReg));
-      break;
-    case RISCVMatInt::RegImm:
-      emitToStreamer(
-          Out, MCInstBuilder(Inst.getOpcode()).addReg(DestReg).addReg(SrcReg).addImm(
-                   Inst.getImm()));
-      break;
-    }
+  SmallVector<MCInst, 8> Seq =
+      RISCVMatInt::generateMCInstSeq(Value, getSTI(), DestReg);
 
-    // Only the first instruction has X0 as its source.
-    SrcReg = DestReg;
+  for (MCInst &Inst : Seq) {
+    emitToStreamer(Out, Inst);
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
index 4358a5b878e631..d873d03d4b11c1 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
@@ -9,6 +9,7 @@
 #include "RISCVMatInt.h"
 #include "MCTargetDesc/RISCVMCTargetDesc.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/MC/MCInstBuilder.h"
 #include "llvm/Support/MathExtras.h"
 using namespace llvm;
 
@@ -469,6 +470,46 @@ InstSeq generateTwoRegInstSeq(int64_t Val, const MCSubtargetInfo &STI,
   return RISCVMatInt::InstSeq();
 }
 
+SmallVector<MCInst, 8>
+generateMCInstSeq(int64_t Val, const MCSubtargetInfo &STI, MCRegister DestReg) {
+  RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Val, STI);
+
+  SmallVector<MCInst, 8> instructions;
+
+  MCRegister SrcReg = RISCV::X0;
+  for (RISCVMatInt::Inst &Inst : Seq) {
+    switch (Inst.getOpndKind()) {
+    case RISCVMatInt::Imm:
+      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+                                 .addReg(DestReg)
+                                 .addImm(Inst.getImm()));
+      break;
+    case RISCVMatInt::RegX0:
+      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+                                 .addReg(DestReg)
+                                 .addReg(SrcReg)
+                                 .addReg(RISCV::X0));
+      break;
+    case RISCVMatInt::RegReg:
+      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+                                 .addReg(DestReg)
+                                 .addReg(SrcReg)
+                                 .addReg(SrcReg));
+      break;
+    case RISCVMatInt::RegImm:
+      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+                                 .addReg(DestReg)
+                                 .addReg(SrcReg)
+                                 .addImm(Inst.getImm()));
+      break;
+    }
+
+    // Only the first instruction has X0 as its source.
+    SrcReg = DestReg;
+  }
+  return instructions;
+}
+
 int getIntMatCost(const APInt &Val, unsigned Size, const MCSubtargetInfo &STI,
                   bool CompressionCost) {
   bool IsRV64 = STI.hasFeature(RISCV::Feature64Bit);
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.h
index 780f685463f300..edd4cdeba8a25a 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.h
@@ -10,7 +10,10 @@
 #define LLVM_LIB_TARGET_RISCV_MCTARGETDESC_MATINT_H
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/MC/MCInst.h"
+#include "llvm/MC/MCRegisterInfo.h"
 #include "llvm/MC/MCSubtargetInfo.h"
+#include "llvm/TargetParser/SubtargetFeature.h"
 #include <cstdint>
 
 namespace llvm {
@@ -56,6 +59,10 @@ InstSeq generateInstSeq(int64_t Val, const MCSubtargetInfo &STI);
 InstSeq generateTwoRegInstSeq(int64_t Val, const MCSubtargetInfo &STI,
                               unsigned &ShiftAmt, unsigned &AddOpc);
 
+// Like generateInstSeq, but returns MCInsts that materialize Val in DestReg.
+SmallVector<MCInst, 8>
+generateMCInstSeq(int64_t Val, const MCSubtargetInfo &STI, MCRegister DestReg);
+
 // Helper to estimate the number of instructions required to materialise the
 // given immediate value into a register. This estimate does not account for
 // `Val` possibly fitting into an immediate, and so may over-estimate.
diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
index f2bd5118fc0717..c8975a9cdc4110 100644
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/RISCVBaseInfo.h"
 #include "MCTargetDesc/RISCVInstPrinter.h"
 #include "MCTargetDesc/RISCVMCExpr.h"
+#include "MCTargetDesc/RISCVMatInt.h"
 #include "MCTargetDesc/RISCVTargetStreamer.h"
 #include "RISCV.h"
 #include "RISCVMachineFunctionInfo.h"
@@ -152,8 +153,35 @@ void RISCVAsmPrinter::LowerPATCHPOINT(MCStreamer &OutStreamer, StackMaps &SM,
 
   PatchPointOpers Opers(&MI);
 
+  const MachineOperand &CalleeMO = Opers.getCallTarget();
   unsigned EncodedBytes = 0;
 
+  if (CalleeMO.isImm()) {
+    uint64_t CallTarget = CalleeMO.getImm();
+    if (CallTarget) {
+      assert((CallTarget & 0xFFFF'FFFF'FFFF) == CallTarget &&
+             "High 16 bits of call target should be zero.");
+      // Materialize the jump address:
+      SmallVector<MCInst, 8> Seq =
+          RISCVMatInt::generateMCInstSeq(CallTarget, *STI, RISCV::X1);
+      for (MCInst &Inst : Seq) {
+        EmitToStreamer(OutStreamer, Inst);
+      }
+      EncodedBytes += Seq.size() * 4;
+      EmitToStreamer(OutStreamer, MCInstBuilder(RISCV::JALR)
+                                      .addReg(RISCV::X1)
+                                      .addReg(RISCV::X1)
+                                      .addImm(0));
+      EncodedBytes += 4; // NOTE(review): assumes 4-byte (uncompressed) insts.
+    }
+  } else if (CalleeMO.isGlobal()) {
+    MCOperand CallTargetMCOp;
+    lowerOperand(CalleeMO, CallTargetMCOp);
+    EmitToStreamer(OutStreamer,
+                   MCInstBuilder(RISCV::PseudoCALL).addOperand(CallTargetMCOp));
+    EncodedBytes += 8;
+  }
+
   // Emit padding.
   unsigned NumBytes = Opers.getNumPatchBytes();
   assert(NumBytes >= EncodedBytes &&
@@ -172,6 +200,35 @@ void RISCVAsmPrinter::LowerSTATEPOINT(MCStreamer &OutStreamer, StackMaps &SM,
     assert(PatchBytes % NOPBytes == 0 &&
            "Invalid number of NOP bytes requested!");
     emitNops(PatchBytes / NOPBytes);
+  } else {
+    // Emit the call. NOTE(review): JAL's imm is PC-relative; confirm intent.
+    const MachineOperand &CallTarget = SOpers.getCallTarget();
+    MCOperand CallTargetMCOp;
+    switch (CallTarget.getType()) {
+    case MachineOperand::MO_GlobalAddress:
+    case MachineOperand::MO_ExternalSymbol:
+      lowerOperand(CallTarget, CallTargetMCOp);
+      EmitToStreamer(
+          OutStreamer,
+          MCInstBuilder(RISCV::PseudoCALL).addOperand(CallTargetMCOp));
+      break;
+    case MachineOperand::MO_Immediate:
+      CallTargetMCOp = MCOperand::createImm(CallTarget.getImm());
+      EmitToStreamer(OutStreamer, MCInstBuilder(RISCV::JAL)
+                                      .addReg(RISCV::X1)
+                                      .addOperand(CallTargetMCOp));
+      break;
+    case MachineOperand::MO_Register:
+      CallTargetMCOp = MCOperand::createReg(CallTarget.getReg());
+      EmitToStreamer(OutStreamer, MCInstBuilder(RISCV::JALR)
+                                      .addReg(RISCV::X1)
+                                      .addOperand(CallTargetMCOp)
+                                      .addImm(0));
+      break;
+    default:
+      llvm_unreachable("Unsupported operand type in statepoint call target");
+      break;
+    }
   }
 
   auto &Ctx = OutStreamer.getContext();
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 79c16cf4c4c361..9efa15af2d4535 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16969,6 +16969,17 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
   case RISCV::PseudoFROUND_D_IN32X:
     return emitFROUND(MI, BB, Subtarget);
   case TargetOpcode::STATEPOINT:
+    // STATEPOINT is a pseudo instruction with no implicit defs/uses, but the
+    // call it is eventually lowered to (PseudoCALL, JAL or JALR) writes X1.
+    // That def is early-clobber: X1 is written at the moment of the call,
+    // before any of the statepoint's uses are read.
+    // Add this implicit dead def here as a workaround.
+    MI.addOperand(*MI.getMF(),
+                  MachineOperand::CreateReg(
+                      RISCV::X1, /*isDef*/ true,
+                      /*isImp*/ true, /*isKill*/ false, /*isDead*/ true,
+                      /*isUndef*/ false, /*isEarlyClobber*/ true));
+    [[fallthrough]];
   case TargetOpcode::STACKMAP:
   case TargetOpcode::PATCHPOINT:
     if (!Subtarget.is64Bit())
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 7f6a045a7d042f..37cd378cb0eaed 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1524,9 +1524,14 @@ unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
   case TargetOpcode::PATCHPOINT:
     // The size of the patchpoint intrinsic is the number of bytes requested
     return PatchPointOpers(&MI).getNumPatchBytes();
-  case TargetOpcode::STATEPOINT:
+  case TargetOpcode::STATEPOINT: {
     // The size of the statepoint intrinsic is the number of bytes requested
-    return StatepointOpers(&MI).getNumPatchBytes();
+    unsigned NumBytes = StatepointOpers(&MI).getNumPatchBytes();
+    // A statepoint is at least a PseudoCALL
+    if (NumBytes < 8)
+      NumBytes = 8;
+    return NumBytes;
+  }
   default:
     return get(Opcode).getSize();
   }
diff --git a/llvm/test/CodeGen/RISCV/rv64-patchpoint.ll b/llvm/test/CodeGen/RISCV/rv64-patchpoint.ll
index d2a3bccfef7bb0..10f80162af0cc1 100644
--- a/llvm/test/CodeGen/RISCV/rv64-patchpoint.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-patchpoint.ll
@@ -1,12 +1,56 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv64 -debug-entry-values -enable-misched=0 < %s | FileCheck %s
 
+; Trivial patchpoint codegen
+; Non-zero immediate targets are materialized into ra and called via jalr.
+define i64 @trivial_patchpoint_codegen(i64 %p1, i64 %p2, i64 %p3, i64 %p4) {
+; CHECK-LABEL: trivial_patchpoint_codegen:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    sd s0, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s1, 0(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset s0, -8
+; CHECK-NEXT:    .cfi_offset s1, -16
+; CHECK-NEXT:    mv s0, a0
+; CHECK-NEXT:  .Ltmp0:
+; CHECK-NEXT:    lui ra, 3563
+; CHECK-NEXT:    addiw ra, ra, -577
+; CHECK-NEXT:    slli ra, ra, 12
+; CHECK-NEXT:    addi ra, ra, -259
+; CHECK-NEXT:    slli ra, ra, 12
+; CHECK-NEXT:    addi ra, ra, -1282
+; CHECK-NEXT:    jalr ra
+; CHECK-NEXT:    mv s1, a0
+; CHECK-NEXT:    mv a0, s0
+; CHECK-NEXT:    mv a1, s1
+; CHECK-NEXT:  .Ltmp1:
+; CHECK-NEXT:    lui ra, 3563
+; CHECK-NEXT:    addiw ra, ra, -577
+; CHECK-NEXT:    slli ra, ra, 12
+; CHECK-NEXT:    addi ra, ra, -259
+; CHECK-NEXT:    slli ra, ra, 12
+; CHECK-NEXT:    addi ra, ra, -1281
+; CHECK-NEXT:    jalr ra
+; CHECK-NEXT:    mv a0, s1
+; CHECK-NEXT:    ld s0, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s1, 0(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+entry:
+  %resolveCall2 = inttoptr i64 244837814094590 to i8*
+  %result = tail call i64 (i64, i32, i8*, i32, ...) @llvm.experimental.patchpoint.i64(i64 2, i32 28, i8* %resolveCall2, i32 4, i64 %p1, i64 %p2, i64 %p3, i64 %p4)
+  %resolveCall3 = inttoptr i64 244837814094591 to i8*
+  tail call void (i64, i32, i8*, i32, ...) @llvm.experimental.patchpoint.void(i64 3, i32 28, i8* %resolveCall3, i32 2, i64 %p1, i64 %result)
+  ret i64 %result
+}
+
 ; Test small patchpoints that don't emit calls.
 define void @small_patchpoint_codegen(i64 %p1, i64 %p2, i64 %p3, i64 %p4) {
 ; CHECK-LABEL: small_patchpoint_codegen:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    .cfi_def_cfa_offset 0
-; CHECK-NEXT:  .Ltmp0:
+; CHECK-NEXT:  .Ltmp2:
 ; CHECK-NEXT:    nop
 ; CHECK-NEXT:    nop
 ; CHECK-NEXT:    nop

>From 4c71d076350a3888a4f47e72cb443645f876c63d Mon Sep 17 00:00:00 2001
From: Sacha Coppey <sacha.coppey at oracle.com>
Date: Mon, 12 Feb 2024 19:55:58 +0100
Subject: [PATCH 2/2] Small fixes: variable naming, brace style, std::max

---
 llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp | 12 ++++++------
 llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp          |  3 +--
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp           |  4 +---
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
index d873d03d4b11c1..49e4bb93669adf 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp
@@ -474,30 +474,30 @@ SmallVector<MCInst, 8>
 generateMCInstSeq(int64_t Val, const MCSubtargetInfo &STI, MCRegister DestReg) {
   RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Val, STI);
 
-  SmallVector<MCInst, 8> instructions;
+  SmallVector<MCInst, 8> Instructions;
 
   MCRegister SrcReg = RISCV::X0;
   for (RISCVMatInt::Inst &Inst : Seq) {
     switch (Inst.getOpndKind()) {
     case RISCVMatInt::Imm:
-      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+      Instructions.push_back(MCInstBuilder(Inst.getOpcode())
                                  .addReg(DestReg)
                                  .addImm(Inst.getImm()));
       break;
     case RISCVMatInt::RegX0:
-      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+      Instructions.push_back(MCInstBuilder(Inst.getOpcode())
                                  .addReg(DestReg)
                                  .addReg(SrcReg)
                                  .addReg(RISCV::X0));
       break;
     case RISCVMatInt::RegReg:
-      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+      Instructions.push_back(MCInstBuilder(Inst.getOpcode())
                                  .addReg(DestReg)
                                  .addReg(SrcReg)
                                  .addReg(SrcReg));
       break;
     case RISCVMatInt::RegImm:
-      instructions.push_back(MCInstBuilder(Inst.getOpcode())
+      Instructions.push_back(MCInstBuilder(Inst.getOpcode())
                                  .addReg(DestReg)
                                  .addReg(SrcReg)
                                  .addImm(Inst.getImm()));
@@ -507,7 +507,7 @@ generateMCInstSeq(int64_t Val, const MCSubtargetInfo &STI, MCRegister DestReg) {
     // Only the first instruction has X0 as its source.
     SrcReg = DestReg;
   }
-  return instructions;
+  return Instructions;
 }
 
 int getIntMatCost(const APInt &Val, unsigned Size, const MCSubtargetInfo &STI,
diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
index c8975a9cdc4110..80fd48eb0eeddd 100644
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -164,9 +164,8 @@ void RISCVAsmPrinter::LowerPATCHPOINT(MCStreamer &OutStreamer, StackMaps &SM,
       // Materialize the jump address:
       SmallVector<MCInst, 8> Seq =
           RISCVMatInt::generateMCInstSeq(CallTarget, *STI, RISCV::X1);
-      for (MCInst &Inst : Seq) {
+      for (MCInst &Inst : Seq)
         EmitToStreamer(OutStreamer, Inst);
-      }
       EncodedBytes += Seq.size() * 4;
       EmitToStreamer(OutStreamer, MCInstBuilder(RISCV::JALR)
                                       .addReg(RISCV::X1)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 37cd378cb0eaed..a31c164434757c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1528,9 +1528,7 @@ unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
     // The size of the statepoint intrinsic is the number of bytes requested
     unsigned NumBytes = StatepointOpers(&MI).getNumPatchBytes();
     // A statepoint is at least a PseudoCALL
-    if (NumBytes < 8)
-      NumBytes = 8;
-    return NumBytes;
+    return std::max(NumBytes, 8U); // 8 bytes: the size of a PseudoCALL.
   }
   default:
     return get(Opcode).getSize();



More information about the llvm-commits mailing list