[llvm] [RISCV] Stack clash protection for dynamic alloca (PR #122508)

Raphael Moreira Zinsly via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 12:30:03 PST 2025


================
@@ -22559,3 +22564,99 @@ unsigned RISCVTargetLowering::getStackProbeSize(const MachineFunction &MF,
   StackProbeSize = alignDown(StackProbeSize, StackAlign.value());
   return StackProbeSize ? StackProbeSize : StackAlign.value();
 }
+
+SDValue RISCVTargetLowering::lowerDYNAMIC_STACKALLOC(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  if (!hasInlineStackProbe(MF))
+    return SDValue();
+
+  MVT XLenVT = Subtarget.getXLenVT();
+  // Get the inputs.
+  SDNode *Node = Op.getNode();
+  SDValue Chain = Op.getOperand(0);
+  SDValue Size = Op.getOperand(1);
+
+  MaybeAlign Align =
+      cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
+  SDLoc dl(Op);
+  EVT VT = Node->getValueType(0);
+
+  // Construct the new SP value in a GPR.
+  SDValue SP = DAG.getCopyFromReg(Chain, dl, RISCV::X2, XLenVT);
+  Chain = SP.getValue(1);
+  SP = DAG.getNode(ISD::SUB, dl, XLenVT, SP, Size);
+  if (Align)
+    SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
+                     DAG.getSignedConstant(-(uint64_t)Align->value(), dl, VT));
+
+  // Set the real SP to the new value with a probing loop.
+  Chain = DAG.getNode(RISCVISD::PROBED_ALLOCA, dl, MVT::Other, Chain, SP);
+  return DAG.getMergeValues({SP, Chain}, dl);
+}
+
+MachineBasicBlock *
+RISCVTargetLowering::emitDynamicProbedAlloc(MachineInstr &MI,
+                                            MachineBasicBlock *MBB) const {
+  MachineFunction &MF = *MBB->getParent();
+  MachineBasicBlock::iterator MBBI = MI.getIterator();
+  DebugLoc DL = MBB->findDebugLoc(MBBI);
+  Register TargetReg = MI.getOperand(1).getReg();
+
+  const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
+  bool IsRV64 = Subtarget.is64Bit();
+  Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
+  const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
+  uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);
+
+  MachineFunction::iterator MBBInsertPoint = std::next(MBB->getIterator());
+  MachineBasicBlock *LoopTestMBB =
+      MF.CreateMachineBasicBlock(MBB->getBasicBlock());
+  MF.insert(MBBInsertPoint, LoopTestMBB);
+  MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB->getBasicBlock());
+  MF.insert(MBBInsertPoint, ExitMBB);
+  MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
----------------
rzinsly wrote:

I removed the flag and added a new member to `RISCVMachineFunctionInfo` to trigger the probe, so there is no need to search for BBs in this way.

https://github.com/llvm/llvm-project/pull/122508


More information about the llvm-commits mailing list