[llvm] [SDAG] Merge multiple-result libcall expansion into DAG.expandMultipleResultFPLibCall() (PR #114792)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 5 07:36:06 PST 2024
================
@@ -2506,78 +2505,115 @@ bool SelectionDAG::expandFSINCOS(SDNode *Node,
return nullptr;
};
+ // For vector types, we must find a vector mapping for the libcall.
VecDesc const *VD = nullptr;
if (VT.isVector() && !(VD = getVecDesc()))
return false;
// Find users of the node that store the results (and share input chains). The
// destination pointers can be used instead of creating stack allocations.
SDValue StoresInChain{};
- std::array<StoreSDNode *, 2> ResultStores = {nullptr};
+ SmallVector<StoreSDNode *, 2> ResultStores(NumResults);
for (SDNode *User : Node->uses()) {
if (!ISD::isNormalStore(User))
continue;
auto *ST = cast<StoreSDNode>(User);
- if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
- ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
+ SDValue StoreValue = ST->getValue();
+ unsigned ResNo = StoreValue.getResNo();
+ Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
+ if (CallRetResNo == ResNo || !ST->isSimple() ||
+ ST->getAddressSpace() != 0 ||
+ ST->getAlign() <
+ getDataLayout().getABITypeAlign(StoreType->getScalarType()) ||
(StoresInChain && ST->getChain() != StoresInChain) ||
Node->isPredecessorOf(ST->getChain().getNode()))
continue;
- ResultStores[ST->getValue().getResNo()] = ST;
+ ResultStores[ResNo] = ST;
StoresInChain = ST->getChain();
}
TargetLowering::ArgListTy Args;
- TargetLowering::ArgListEntry Entry{};
+ auto AddArgListEntry = [&](SDValue Node, Type *Ty) {
+ TargetLowering::ArgListEntry Entry{};
+ Entry.Ty = Ty;
+ Entry.Node = Node;
+ Args.push_back(Entry);
+ };
- // Pass the argument.
- Entry.Node = Node->getOperand(0);
- Entry.Ty = Ty;
- Args.push_back(Entry);
+ // Pass the arguments.
+ for (const SDValue &Op : Node->op_values()) {
+ EVT ArgVT = Op.getValueType();
+ Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
+ AddArgListEntry(Op, ArgTy);
+ }
- // Pass the output pointers for sin and cos.
- SmallVector<SDValue, 2> ResultPtrs{};
- for (StoreSDNode *ST : ResultStores) {
- SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
- Entry.Node = ResultPtr;
- Entry.Ty = PointerType::getUnqual(Ty->getContext());
- Args.push_back(Entry);
- ResultPtrs.push_back(ResultPtr);
+ // Pass the output pointers.
+ SmallVector<SDValue, 2> ResultPtrs(NumResults);
+ Type *PointerTy = PointerType::getUnqual(Ctx);
+ for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
+ if (ResNo == CallRetResNo)
+ continue;
+ EVT ResVT = Node->getValueType(ResNo);
+ SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(ResVT);
+ ResultPtrs[ResNo] = ResultPtr;
+ AddArgListEntry(ResultPtr, PointerTy);
}
SDLoc DL(Node);
+ // Pass the vector mask (if required).
if (VD && VD->isMasked()) {
- EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
- Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
- Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
- Args.push_back(Entry);
+ EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), Ctx, VT);
+ SDValue Mask = getBoolConstant(true, DL, MaskVT, VT);
+ AddArgListEntry(Mask, MaskVT.getTypeForEVT(Ctx));
}
+ Type *RetType = CallRetResNo.has_value()
+ ? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
+ : Type::getVoidTy(Ctx);
SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
TLI->getPointerTy(getDataLayout()));
TargetLowering::CallLoweringInfo CLI(*this);
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
- TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
- std::move(Args));
+ TLI->getLibcallCallingConv(LC), RetType, Callee, std::move(Args));
- auto [Call, OutChain] = TLI->LowerCallTo(CLI);
+ auto [Call, CallChain] = TLI->LowerCallTo(CLI);
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
+ if (ResNo == CallRetResNo) {
+ Results.push_back(Call);
+ continue;
+ }
MachinePointerInfo PtrInfo;
if (StoreSDNode *ST = ResultStores[ResNo]) {
// Replace store with the library call.
- ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
+ ReplaceAllUsesOfValueWith(SDValue(ST, 0), CallChain);
PtrInfo = ST->getPointerInfo();
} else {
PtrInfo = MachinePointerInfo::getFixedStack(
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
}
- SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
+ SDValue LoadResult =
+ getLoad(Node->getValueType(ResNo), DL, CallChain, ResultPtr, PtrInfo);
Results.push_back(LoadResult);
}
+ // FIXME: Find a way to avoid updating the root. This is needed for x86, which
+ // uses a floating-point stack. For example, suppose the node to be expanded
+ // has two results: a floating-point result that is returned by the call, and
+ // an integer result that is returned via an output pointer. If only the
+ // integer result is used, then the `CopyFromReg` for the FP result may be
+ // optimized out. This
+ // prevents an FP stack pop from being emitted for it. Setting the root like
+ // this ensures there will be a use of the `CopyFromReg` chain, and ensures
+ // the FP pop will be emitted.
+ SDValue OutputChain =
+ getNode(ISD::TokenFactor, DL, MVT::Other, getRoot(), CallChain);
+ setRoot(OutputChain);
+
+ // Ensure the new root is reachable from the results.
+ Results[0] = getMergeValues({Results[0], OutputChain}, DL);
----------------
MacDue wrote:
Done :+1:
https://github.com/llvm/llvm-project/pull/114792
More information about the llvm-commits
mailing list