[llvm] 9d469b5 - [RISCV] Implement trampolines for rv64 (#96309)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 17 23:06:50 PDT 2024
Author: Roger Ferrer Ibáñez
Date: 2024-10-18T08:06:47+02:00
New Revision: 9d469b5988bfb1c2e99533f863b1f9eb5b0c58b7
URL: https://github.com/llvm/llvm-project/commit/9d469b5988bfb1c2e99533f863b1f9eb5b0c58b7
DIFF: https://github.com/llvm/llvm-project/commit/9d469b5988bfb1c2e99533f863b1f9eb5b0c58b7.diff
LOG: [RISCV] Implement trampolines for rv64 (#96309)
This implementation is based on what the X86 target does, but it emits the
instructions that GCC emits for rv64.
---------
Co-authored-by: Pengcheng Wang <wangpengcheng.pp at bytedance.com>
Added:
llvm/test/CodeGen/RISCV/rv64-trampoline.ll
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
Removed:
################################################################################
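For context, the new lowering handles the llvm.init.trampoline and llvm.adjust.trampoline intrinsics, which the test added at the end of this patch exercises. Below is a minimal usage sketch in LLVM IR (not part of the commit); it assumes, like the new test, a 32-byte buffer with 8-byte alignment and a callee that takes its static chain through a 'nest' parameter, and the names @callee and @make_tramp_call are purely illustrative.

declare void @llvm.init.trampoline(ptr, ptr, ptr)
declare ptr @llvm.adjust.trampoline(ptr)
declare i64 @callee(ptr nest, i64)

define i64 @make_tramp_call(i64 %n, ptr %chain) {
  ; 32 bytes: 16 bytes of code followed by the static chain and the
  ; target function address (offsets 16 and 24 in this lowering).
  %buf = alloca [32 x i8], align 8
  ; Fills the buffer with the auipc/ld/ld/jalr sequence plus the two
  ; pointers and flushes the instruction cache over the code part.
  call void @llvm.init.trampoline(ptr %buf, ptr @callee, ptr %chain)
  ; On RISC-V the adjusted entry point is simply the buffer address.
  %fn = call ptr @llvm.adjust.trampoline(ptr %buf)
  ; Calling through %fn passes %chain to @callee as its nest argument.
  %r = call i64 %fn(i64 %n)
  ret i64 %r
}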
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 952072c26739f9..fa157ca48db21b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -37,6 +37,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/MC/MCCodeEmitter.h"
+#include "llvm/MC/MCInstBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -625,6 +627,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
Subtarget.is64Bit() ? Legal : Custom);
+ if (Subtarget.is64Bit()) {
+ setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
+ setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
+ }
+
setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
if (Subtarget.is64Bit())
@@ -7402,6 +7409,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
Op.getOperand(2), Flags, DL);
}
+ case ISD::INIT_TRAMPOLINE:
+ return lowerINIT_TRAMPOLINE(Op, DAG);
+ case ISD::ADJUST_TRAMPOLINE:
+ return lowerADJUST_TRAMPOLINE(Op, DAG);
}
}
@@ -7417,6 +7428,126 @@ SDValue RISCVTargetLowering::emitFlushICache(SelectionDAG &DAG, SDValue InChain,
return CallResult.second;
}
+SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
+ SelectionDAG &DAG) const {
+ if (!Subtarget.is64Bit())
+ llvm::report_fatal_error("Trampolines only implemented for RV64");
+
+ // Create an MCCodeEmitter to encode instructions.
+ TargetLoweringObjectFile *TLO = getTargetMachine().getObjFileLowering();
+ assert(TLO);
+ MCContext &MCCtx = TLO->getContext();
+
+ std::unique_ptr<MCCodeEmitter> CodeEmitter(
+ createRISCVMCCodeEmitter(*getTargetMachine().getMCInstrInfo(), MCCtx));
+
+ SDValue Root = Op.getOperand(0);
+ SDValue Trmp = Op.getOperand(1); // trampoline
+ SDLoc dl(Op);
+
+ const Value *TrmpAddr = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
+
+ // We store in the trampoline buffer the following instructions and data.
+ // Offset:
+ // 0: auipc t2, 0
+ // 4: ld t0, 24(t2)
+ // 8: ld t2, 16(t2)
+ // 12: jr t0
+ // 16: <StaticChainOffset>
+ // 24: <FunctionAddressOffset>
+ // 32:
+
+ constexpr unsigned StaticChainOffset = 16;
+ constexpr unsigned FunctionAddressOffset = 24;
+
+ const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
+ assert(STI);
+ auto GetEncoding = [&](const MCInst &MC) {
+ SmallVector<char, 4> CB;
+ SmallVector<MCFixup> Fixups;
+ CodeEmitter->encodeInstruction(MC, CB, Fixups, *STI);
+ uint32_t Encoding = support::endian::read32le(CB.data());
+ return Encoding;
+ };
+
+ SDValue OutChains[6];
+
+ uint32_t Encodings[] = {
+ // auipc t2, 0
+ // Loads the current PC into t2.
+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
+ // ld t0, 24(t2)
+ // Loads the function address into t0. Note that the offsets are relative
+ // to the first trampoline instruction, whose address auipc placed in t2.
+ GetEncoding(
+ MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
+ FunctionAddressOffset)),
+ // ld t2, 16(t2)
+ // Load the value of the static chain.
+ GetEncoding(
+ MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
+ StaticChainOffset)),
+ // jr t0 (jalr x0, t0, 0)
+ // Jump to the function; rd is x0, so ra is left untouched and the
+ // callee returns to the trampoline's caller.
+ GetEncoding(MCInstBuilder(RISCV::JALR)
+ .addReg(RISCV::X0)
+ .addReg(RISCV::X5)
+ .addImm(0))};
+
+ // Store encoded instructions.
+ for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
+ SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
+ DAG.getConstant(Idx * 4, dl, MVT::i64))
+ : Trmp;
+ OutChains[Idx] = DAG.getTruncStore(
+ Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
+ MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
+ }
+
+ // Now store the variable part of the trampoline.
+ SDValue FunctionAddress = Op.getOperand(2);
+ SDValue StaticChain = Op.getOperand(3);
+
+ // Store the given static chain and function pointer in the trampoline buffer.
+ struct OffsetValuePair {
+ const unsigned Offset;
+ const SDValue Value;
+ SDValue Addr = SDValue(); // Used to cache the address.
+ } OffsetValues[] = {
+ {StaticChainOffset, StaticChain},
+ {FunctionAddressOffset, FunctionAddress},
+ };
+ for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
+ SDValue Addr =
+ DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
+ DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
+ OffsetValue.Addr = Addr;
+ OutChains[Idx + 4] =
+ DAG.getStore(Root, dl, OffsetValue.Value, Addr,
+ MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
+ }
+
+ SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+
+ // The end of the trampoline's instructions coincides with the static chain
+ // address that we computed earlier.
+ SDValue EndOfTrmp = OffsetValues[0].Addr;
+
+ // Call clear cache on the trampoline instructions.
+ SDValue Chain = DAG.getNode(ISD::CLEAR_CACHE, dl, MVT::Other, StoreToken,
+ Trmp, EndOfTrmp);
+
+ return Chain;
+}
+
+SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
+ SelectionDAG &DAG) const {
+ if (!Subtarget.is64Bit())
+ llvm::report_fatal_error("Trampolines only implemented for RV64");
+
+ return Op.getOperand(0);
+}
+
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
SelectionDAG &DAG, unsigned Flags) {
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 3864d58a129e98..c3749447955330 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -992,6 +992,9 @@ class RISCVTargetLowering : public TargetLowering {
SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+
bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
const SmallVector<CCValAssign, 16> &ArgLocs) const;
diff --git a/llvm/test/CodeGen/RISCV/rv64-trampoline.ll b/llvm/test/CodeGen/RISCV/rv64-trampoline.ll
new file mode 100644
index 00000000000000..ba184063265098
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rv64-trampoline.ll
@@ -0,0 +1,80 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64 %s
+; RUN: llc -mtriple=riscv64-unknown-linux-gnu -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64-LINUX %s
+
+declare void @llvm.init.trampoline(ptr, ptr, ptr)
+declare ptr @llvm.adjust.trampoline(ptr)
+declare i64 @f(ptr nest, i64)
+
+define i64 @test0(i64 %n, ptr %p) nounwind {
+; RV64-LABEL: test0:
+; RV64: # %bb.0:
+; RV64-NEXT: addi sp, sp, -64
+; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-NEXT: sd s0, 48(sp) # 8-byte Folded Spill
+; RV64-NEXT: sd s1, 40(sp) # 8-byte Folded Spill
+; RV64-NEXT: mv s0, a0
+; RV64-NEXT: lui a0, %hi(f)
+; RV64-NEXT: addi a0, a0, %lo(f)
+; RV64-NEXT: sd a0, 32(sp)
+; RV64-NEXT: li a0, 919
+; RV64-NEXT: lui a2, %hi(.LCPI0_0)
+; RV64-NEXT: ld a2, %lo(.LCPI0_0)(a2)
+; RV64-NEXT: lui a3, 6203
+; RV64-NEXT: addi a3, a3, 643
+; RV64-NEXT: sw a0, 8(sp)
+; RV64-NEXT: sw a3, 12(sp)
+; RV64-NEXT: sd a2, 16(sp)
+; RV64-NEXT: sd a1, 24(sp)
+; RV64-NEXT: addi a1, sp, 24
+; RV64-NEXT: addi a0, sp, 8
+; RV64-NEXT: addi s1, sp, 8
+; RV64-NEXT: call __clear_cache
+; RV64-NEXT: mv a0, s0
+; RV64-NEXT: jalr s1
+; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-NEXT: ld s0, 48(sp) # 8-byte Folded Reload
+; RV64-NEXT: ld s1, 40(sp) # 8-byte Folded Reload
+; RV64-NEXT: addi sp, sp, 64
+; RV64-NEXT: ret
+;
+; RV64-LINUX-LABEL: test0:
+; RV64-LINUX: # %bb.0:
+; RV64-LINUX-NEXT: addi sp, sp, -64
+; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: sd s0, 48(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: sd s1, 40(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT: mv s0, a0
+; RV64-LINUX-NEXT: lui a0, %hi(f)
+; RV64-LINUX-NEXT: addi a0, a0, %lo(f)
+; RV64-LINUX-NEXT: sd a0, 32(sp)
+; RV64-LINUX-NEXT: li a0, 919
+; RV64-LINUX-NEXT: lui a2, %hi(.LCPI0_0)
+; RV64-LINUX-NEXT: ld a2, %lo(.LCPI0_0)(a2)
+; RV64-LINUX-NEXT: lui a3, 6203
+; RV64-LINUX-NEXT: addi a3, a3, 643
+; RV64-LINUX-NEXT: sw a0, 8(sp)
+; RV64-LINUX-NEXT: sw a3, 12(sp)
+; RV64-LINUX-NEXT: sd a2, 16(sp)
+; RV64-LINUX-NEXT: sd a1, 24(sp)
+; RV64-LINUX-NEXT: addi a1, sp, 24
+; RV64-LINUX-NEXT: addi a0, sp, 8
+; RV64-LINUX-NEXT: addi s1, sp, 8
+; RV64-LINUX-NEXT: li a2, 0
+; RV64-LINUX-NEXT: call __riscv_flush_icache
+; RV64-LINUX-NEXT: mv a0, s0
+; RV64-LINUX-NEXT: jalr s1
+; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: ld s0, 48(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: ld s1, 40(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT: addi sp, sp, 64
+; RV64-LINUX-NEXT: ret
+ %alloca = alloca [32 x i8], align 8
+ call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
+ %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
+ %ret = call i64 %tramp(i64 %n)
+ ret i64 %ret
+
+}