[llvm] [SDAG] Merge multiple-result libcall expansion into DAG.expandMultipleResultFPLibCall() (PR #114792)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 07:36:06 PST 2024


================
@@ -2506,78 +2505,115 @@ bool SelectionDAG::expandFSINCOS(SDNode *Node,
     return nullptr;
   };
 
+  // For vector types, we must find a vector mapping for the libcall.
   VecDesc const *VD = nullptr;
   if (VT.isVector() && !(VD = getVecDesc()))
     return false;
 
   // Find users of the node that store the results (and share input chains). The
   // destination pointers can be used instead of creating stack allocations.
   SDValue StoresInChain{};
-  std::array<StoreSDNode *, 2> ResultStores = {nullptr};
+  SmallVector<StoreSDNode *, 2> ResultStores(NumResults);
   for (SDNode *User : Node->uses()) {
     if (!ISD::isNormalStore(User))
       continue;
     auto *ST = cast<StoreSDNode>(User);
-    if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
-        ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
+    SDValue StoreValue = ST->getValue();
+    unsigned ResNo = StoreValue.getResNo();
+    Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
+    if (CallRetResNo == ResNo || !ST->isSimple() ||
+        ST->getAddressSpace() != 0 ||
+        ST->getAlign() <
+            getDataLayout().getABITypeAlign(StoreType->getScalarType()) ||
         (StoresInChain && ST->getChain() != StoresInChain) ||
         Node->isPredecessorOf(ST->getChain().getNode()))
       continue;
-    ResultStores[ST->getValue().getResNo()] = ST;
+    ResultStores[ResNo] = ST;
     StoresInChain = ST->getChain();
   }
 
   TargetLowering::ArgListTy Args;
-  TargetLowering::ArgListEntry Entry{};
+  auto AddArgListEntry = [&](SDValue Node, Type *Ty) {
+    TargetLowering::ArgListEntry Entry{};
+    Entry.Ty = Ty;
+    Entry.Node = Node;
+    Args.push_back(Entry);
+  };
 
-  // Pass the argument.
-  Entry.Node = Node->getOperand(0);
-  Entry.Ty = Ty;
-  Args.push_back(Entry);
+  // Pass the arguments.
+  for (const SDValue &Op : Node->op_values()) {
+    EVT ArgVT = Op.getValueType();
+    Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
+    AddArgListEntry(Op, ArgTy);
+  }
 
-  // Pass the output pointers for sin and cos.
-  SmallVector<SDValue, 2> ResultPtrs{};
-  for (StoreSDNode *ST : ResultStores) {
-    SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
-    Entry.Node = ResultPtr;
-    Entry.Ty = PointerType::getUnqual(Ty->getContext());
-    Args.push_back(Entry);
-    ResultPtrs.push_back(ResultPtr);
+  // Pass the output pointers.
+  SmallVector<SDValue, 2> ResultPtrs(NumResults);
+  Type *PointerTy = PointerType::getUnqual(Ctx);
+  for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
+    if (ResNo == CallRetResNo)
+      continue;
+    EVT ResVT = Node->getValueType(ResNo);
+    SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(ResVT);
+    ResultPtrs[ResNo] = ResultPtr;
+    AddArgListEntry(ResultPtr, PointerTy);
   }
 
   SDLoc DL(Node);
 
+  // Pass the vector mask (if required).
   if (VD && VD->isMasked()) {
-    EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
-    Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
-    Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
-    Args.push_back(Entry);
+    EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), Ctx, VT);
+    SDValue Mask = getBoolConstant(true, DL, MaskVT, VT);
+    AddArgListEntry(Mask, MaskVT.getTypeForEVT(Ctx));
   }
 
+  Type *RetType = CallRetResNo.has_value()
+                      ? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
+                      : Type::getVoidTy(Ctx);
   SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
   SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
                                      TLI->getPointerTy(getDataLayout()));
   TargetLowering::CallLoweringInfo CLI(*this);
   CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
-      TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
-      std::move(Args));
+      TLI->getLibcallCallingConv(LC), RetType, Callee, std::move(Args));
 
-  auto [Call, OutChain] = TLI->LowerCallTo(CLI);
+  auto [Call, CallChain] = TLI->LowerCallTo(CLI);
 
   for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
+    if (ResNo == CallRetResNo) {
+      Results.push_back(Call);
+      continue;
+    }
     MachinePointerInfo PtrInfo;
     if (StoreSDNode *ST = ResultStores[ResNo]) {
       // Replace store with the library call.
-      ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
+      ReplaceAllUsesOfValueWith(SDValue(ST, 0), CallChain);
       PtrInfo = ST->getPointerInfo();
     } else {
       PtrInfo = MachinePointerInfo::getFixedStack(
           getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
     }
-    SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
+    SDValue LoadResult =
+        getLoad(Node->getValueType(ResNo), DL, CallChain, ResultPtr, PtrInfo);
     Results.push_back(LoadResult);
   }
 
+  // FIXME: Find a way to avoid updating the root. This is needed for x86, which
+  // uses a floating-point stack. For example, suppose the node to be expanded
+  // has two results: an FP result returned by the call, and an integer result
+  // returned via an output pointer. If only the integer result is used, the
+  // `CopyFromReg` for the FP result may be optimized out, which prevents an FP
+  // stack pop from being emitted for it. Setting the root like this ensures
+  // there will be a use of the `CopyFromReg` chain, so the FP pop will still
+  // be emitted.
+  SDValue OutputChain =
+      getNode(ISD::TokenFactor, DL, MVT::Other, getRoot(), CallChain);
+  setRoot(OutputChain);
+
+  // Ensure the new root is reachable from the results.
+  Results[0] = getMergeValues({Results[0], OutputChain}, DL);
----------------
MacDue wrote:

Done :+1: 

https://github.com/llvm/llvm-project/pull/114792


More information about the llvm-commits mailing list