[llvm] [RISCV] Fix musttail with indirect arguments by forwarding incoming pointers (PR #185094)

Xavier Roche via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 6 23:29:10 PST 2026


================
@@ -24765,51 +24780,70 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
     // Promote the value if needed.
     // For now, only handle fully promoted and indirect arguments.
     if (VA.getLocInfo() == CCValAssign::Indirect) {
-      // Store the argument in a stack slot and pass its address.
-      Align StackAlign =
-          std::max(getPrefTypeAlign(Outs[OutIdx].ArgVT, DAG),
-                   getPrefTypeAlign(ArgValue.getValueType(), DAG));
-      TypeSize StoredSize = ArgValue.getValueType().getStoreSize();
-      // If the original argument was split (e.g. i128), we need
-      // to store the required parts of it here (and pass just one address).
-      // Vectors may be partly split to registers and partly to the stack, in
-      // which case the base address is partly offset and subsequent stores are
-      // relative to that.
-      unsigned ArgIndex = Outs[OutIdx].OrigArgIndex;
-      unsigned ArgPartOffset = Outs[OutIdx].PartOffset;
-      assert(VA.getValVT().isVector() || ArgPartOffset == 0);
-      // Calculate the total size to store. We don't have access to what we're
-      // actually storing other than performing the loop and collecting the
-      // info.
-      SmallVector<std::pair<SDValue, SDValue>> Parts;
-      while (i + 1 != e && Outs[OutIdx + 1].OrigArgIndex == ArgIndex) {
-        SDValue PartValue = OutVals[OutIdx + 1];
-        unsigned PartOffset = Outs[OutIdx + 1].PartOffset - ArgPartOffset;
-        SDValue Offset = DAG.getIntPtrConstant(PartOffset, DL);
-        EVT PartVT = PartValue.getValueType();
-        if (PartVT.isScalableVector())
-          Offset = DAG.getNode(ISD::VSCALE, DL, XLenVT, Offset);
-        StoredSize += PartVT.getStoreSize();
-        StackAlign = std::max(StackAlign, getPrefTypeAlign(PartVT, DAG));
-        Parts.push_back(std::make_pair(PartValue, Offset));
-        ++i;
-        ++OutIdx;
-      }
-      SDValue SpillSlot = DAG.CreateStackTemporary(StoredSize, StackAlign);
-      int FI = cast<FrameIndexSDNode>(SpillSlot)->getIndex();
-      MemOpChains.push_back(
-          DAG.getStore(Chain, DL, ArgValue, SpillSlot,
-                       MachinePointerInfo::getFixedStack(MF, FI)));
-      for (const auto &Part : Parts) {
-        SDValue PartValue = Part.first;
-        SDValue PartOffset = Part.second;
-        SDValue Address =
-            DAG.getNode(ISD::ADD, DL, PtrVT, SpillSlot, PartOffset);
+      // For musttail calls, forward the incoming indirect pointer instead
+      // of creating a new stack temporary. The incoming pointer points to
+      // the caller's caller's frame, which remains valid after a tail call.
+      if (IsTailCall && CLI.CB && CLI.CB->isMustTailCall()) {
+        unsigned IndirectIdx = 0;
+        for (unsigned k = 0; k < OutIdx; ++k) {
+          if (ArgLocs[k].getLocInfo() == CCValAssign::Indirect)
+            ++IndirectIdx;
+        }
+        ArgValue = RVFI->getIncomingIndirectArg(IndirectIdx);
+        // Skip any split parts of this argument (they are covered by the
+        // forwarded pointer).
+        unsigned ArgIndex = Outs[OutIdx].OrigArgIndex;
+        while (i + 1 != e && Outs[OutIdx + 1].OrigArgIndex == ArgIndex) {
+          ++i;
+          ++OutIdx;
+        }
+      } else {
----------------
xroche wrote:

My understanding is that for non-musttail calls, detecting that a tail call forwards an indirect arg from the caller is a tad more complicated: it requires tracing the DAG value back to determine whether it originated from an indirect formal argument.

The existing check in isEligibleForTailCallOptimization should handle non-musttail cases.
(And for musttail, the LLVM IR verifier guarantees matching prototypes.)

I've added a test showing that a non-musttail tail call forwarding an indirect arg correctly falls back to a regular call.

However, this is where my lack of expertise is visible; I will add more test coverage for mixed direct and indirect arguments, multiple indirect arguments (because of the use of the map), and musttail with int128 plus something else.
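
For the multiple-indirect case, the ordinal lookup done by the counting loop in the quoted hunk can be modeled in isolation. This is only a simplified sketch: `LocInfo` and `indirectOrdinal` are hypothetical stand-ins, not LLVM's `CCValAssign` or any function in the patch, and split arguments and the incoming-pointer map itself are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CCValAssign::LocInfo; not the LLVM type.
enum class LocInfo { Full, Indirect };

// Counts the entries before position `idx` that are marked Indirect.
// For an indirect entry at `idx`, the result is its ordinal among all
// indirect entries, which is the kind of key the hunk passes to
// getIncomingIndirectArg.
unsigned indirectOrdinal(const std::vector<LocInfo> &locs, std::size_t idx) {
  assert(idx <= locs.size() && "index out of range");
  unsigned ordinal = 0;
  for (std::size_t k = 0; k < idx; ++k)
    if (locs[k] == LocInfo::Indirect)
      ++ordinal;
  return ordinal;
}
```

With the locations {Full, Indirect, Full, Indirect}, the entry at position 1 gets ordinal 0 and the entry at position 3 gets ordinal 1, so mixed direct and indirect arguments should not shift the lookup as long as caller and callee prototypes match.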

https://github.com/llvm/llvm-project/pull/185094

