[llvm] [RISCV] Stack clash protection for dynamic alloca (PR #122508)
Raphael Moreira Zinsly via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 14 12:30:03 PST 2025
================
@@ -22559,3 +22564,99 @@ unsigned RISCVTargetLowering::getStackProbeSize(const MachineFunction &MF,
StackProbeSize = alignDown(StackProbeSize, StackAlign.value());
return StackProbeSize ? StackProbeSize : StackAlign.value();
}
+
+/// Custom lowering for ISD::DYNAMIC_STACKALLOC when inline stack probing is
+/// enabled for the function. Computes the new stack pointer (SP - Size,
+/// rounded down to the requested alignment, if any) in a GPR and emits a
+/// RISCVISD::PROBED_ALLOCA node that moves the real SP to that value.
+/// Returns the merged {new SP, chain} pair, or an empty SDValue when inline
+/// stack probes are not in use (no custom lowering is performed here).
+SDValue RISCVTargetLowering::lowerDYNAMIC_STACKALLOC(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  if (!hasInlineStackProbe(MF))
+    return SDValue();
+
+  MVT XLenVT = Subtarget.getXLenVT();
+  // The new SP is computed from a copy of X2, which is XLenVT-typed, so the
+  // allocation result must have that same type; all nodes below use XLenVT.
+  assert(Op.getValueType() == XLenVT &&
+         "DYNAMIC_STACKALLOC result type must match XLenVT");
+
+  // Get the inputs.
+  SDValue Chain = Op.getOperand(0);
+  SDValue Size = Op.getOperand(1);
+  MaybeAlign Align =
+      cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
+  SDLoc dl(Op);
+
+  // Construct the new SP value in a GPR.
+  SDValue SP = DAG.getCopyFromReg(Chain, dl, RISCV::X2, XLenVT);
+  Chain = SP.getValue(1);
+  SP = DAG.getNode(ISD::SUB, dl, XLenVT, SP, Size);
+  // Round down to the requested alignment by masking with -Align. Negate in
+  // the signed domain rather than negating an unsigned value and relying on
+  // the implicit unsigned -> int64_t conversion.
+  if (Align)
+    SP = DAG.getNode(ISD::AND, dl, XLenVT, SP,
+                     DAG.getSignedConstant(
+                         -static_cast<int64_t>(Align->value()), dl, XLenVT));
+
+  // Set the real SP to the new value with a probing loop.
+  Chain = DAG.getNode(RISCVISD::PROBED_ALLOCA, dl, MVT::Other, Chain, SP);
+  return DAG.getMergeValues({SP, Chain}, dl);
+}
+
+MachineBasicBlock *
+RISCVTargetLowering::emitDynamicProbedAlloc(MachineInstr &MI,
+ MachineBasicBlock *MBB) const {
+ MachineFunction &MF = *MBB->getParent();
+ MachineBasicBlock::iterator MBBI = MI.getIterator();
+ DebugLoc DL = MBB->findDebugLoc(MBBI);
+ Register TargetReg = MI.getOperand(1).getReg();
+
+ const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
+ bool IsRV64 = Subtarget.is64Bit();
+ Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
+ const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
+ uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);
+
+ MachineFunction::iterator MBBInsertPoint = std::next(MBB->getIterator());
+ MachineBasicBlock *LoopTestMBB =
+ MF.CreateMachineBasicBlock(MBB->getBasicBlock());
+ MF.insert(MBBInsertPoint, LoopTestMBB);
+ MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB->getBasicBlock());
+ MF.insert(MBBInsertPoint, ExitMBB);
+ MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
----------------
rzinsly wrote:
I removed the flag and added a new member to `RISCVMachineFunctionInfo` to trigger the probe, so there is no longer any need to search for the BBs that way.
https://github.com/llvm/llvm-project/pull/122508
More information about the llvm-commits
mailing list