[llvm] [NVPTX] Further cleanup call isel (PR #146411)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 12:35:09 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>



---

Patch is 315.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146411.diff


15 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+150-153) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10-6) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+23-35) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm60.ll (+540-540) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm70.ll (+540-540) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm90.ll (+540-540) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg.ll (+120-120) 
- (modified) llvm/test/CodeGen/NVPTX/convert-int-sm20.ll (+3-3) 
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+9-12) 
- (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+30-30) 
- (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+96-96) 
- (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d9192fbfceff1..a41b094faa8d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetCallingConv.h"
@@ -390,35 +391,27 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
 /// and promote them to a larger size if they're not.
 ///
-/// The promoted type is placed in \p PromoteVT if the function returns true.
-static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
+/// Returns the promoted type, or \p VT unchanged if no promotion is needed.
+static EVT promoteScalarIntegerPTX(const EVT VT) {
   if (VT.isScalarInteger()) {
-    MVT PromotedVT;
     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
     default:
       llvm_unreachable(
           "Promotion is not suitable for scalars of size larger than 64-bits");
     case 1:
-      PromotedVT = MVT::i1;
-      break;
+      return MVT::i1;
     case 2:
     case 4:
     case 8:
-      PromotedVT = MVT::i8;
-      break;
+      return MVT::i8;
     case 16:
-      PromotedVT = MVT::i16;
-      break;
+      return MVT::i16;
     case 32:
-      PromotedVT = MVT::i32;
-      break;
+      return MVT::i32;
     case 64:
-      PromotedVT = MVT::i64;
-      break;
+      return MVT::i64;
     }
-    if (VT != PromotedVT)
-      return PromotedVT;
   }
-  return std::nullopt;
+  return VT;
 }

 // Check whether we can merge loads/stores of some of the pieces of a
@@ -1053,10 +1046,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
 
     MAKE_CASE(NVPTXISD::RET_GLUE)
-    MAKE_CASE(NVPTXISD::DeclareParam)
+    MAKE_CASE(NVPTXISD::DeclareArrayParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
-    MAKE_CASE(NVPTXISD::DeclareRet)
-    MAKE_CASE(NVPTXISD::DeclareRetParam)
     MAKE_CASE(NVPTXISD::CALL)
     MAKE_CASE(NVPTXISD::LoadParam)
     MAKE_CASE(NVPTXISD::LoadParamV2)
@@ -1162,8 +1153,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
 }
 
 std::string NVPTXTargetLowering::getPrototype(
-    const DataLayout &DL, Type *retTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+    const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
+    const SmallVectorImpl<ISD::OutputArg> &Outs,
     std::optional<unsigned> FirstVAArg, const CallBase &CB,
     unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
@@ -1172,22 +1163,22 @@ std::string NVPTXTargetLowering::getPrototype(
   raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
-  if (retTy->isVoidTy()) {
+  if (RetTy->isVoidTy()) {
     O << "()";
   } else {
     O << "(";
-    if (shouldPassAsArray(retTy)) {
-      assert(RetAlign && "RetAlign must be set for non-void return types");
-      O << ".param .align " << RetAlign->value() << " .b8 _["
-        << DL.getTypeAllocSize(retTy) << "]";
-    } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
+    if (shouldPassAsArray(RetTy)) {
+      const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
+      O << ".param .align " << RetAlign.value() << " .b8 _["
+        << DL.getTypeAllocSize(RetTy) << "]";
+    } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
       unsigned size = 0;
-      if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
+      if (auto *ITy = dyn_cast<IntegerType>(RetTy)) {
         size = ITy->getBitWidth();
       } else {
-        assert(retTy->isFloatingPointTy() &&
+        assert(RetTy->isFloatingPointTy() &&
                "Floating point type expected here");
-        size = retTy->getPrimitiveSizeInBits();
+        size = RetTy->getPrimitiveSizeInBits();
       }
       // PTX ABI requires all scalar return values to be at least 32
       // bits in size.  fp16 normally uses .b16 as its storage type in
@@ -1195,7 +1186,7 @@ std::string NVPTXTargetLowering::getPrototype(
       size = promoteScalarArgumentSize(size);
 
       O << ".param .b" << size << " _";
-    } else if (isa<PointerType>(retTy)) {
+    } else if (isa<PointerType>(RetTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
     } else {
       llvm_unreachable("Unknown return type");
@@ -1256,7 +1247,7 @@ std::string NVPTXTargetLowering::getPrototype(
 
   if (FirstVAArg)
     O << (first ? "" : ",") << " .param .align "
-      << STI.getMaxRequiredAlignment() << " .b8 _[]\n";
+      << STI.getMaxRequiredAlignment() << " .b8 _[]";
   O << ")";
   if (shouldEmitPTXNoReturn(&CB, *nvTM))
     O << " .noreturn";
@@ -1442,6 +1433,21 @@ static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
   return ISD::ANY_EXTEND;
 }
 
+/// Extends (per \p Flags) or truncates integer \p V to \p ExpectedVT.
+static SDValue correctParamType(SDValue V, EVT ExpectedVT,
+                                ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
+                                const SDLoc &dl) {
+  const EVT ActualVT = V.getValueType();
+  assert((ActualVT == ExpectedVT ||
+          (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
+         "Non-integer argument type size mismatch");
+  if (ExpectedVT.bitsGT(ActualVT))
+    return DAG.getNode(getExtOpcode(Flags), dl, ExpectedVT, V);
+  if (ExpectedVT.bitsLT(ActualVT))
+    return DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, V);
+  return V;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1505,9 +1511,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
          "Outs and OutVals must be the same size");
   // Declare the .params or .reg need to pass values
   // to the function
-  for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
-    const auto ArgOuts = AllOuts.take_while(
-        [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
+  for (const auto E : llvm::enumerate(Args)) {
+    const auto ArgI = E.index();
+    const auto Arg = E.value();
+    const auto ArgOuts =
+        AllOuts.take_while([&](auto O) { return O.OrigArgIndex == ArgI; });
     const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
     AllOuts = AllOuts.drop_front(ArgOuts.size());
     AllOutVals = AllOutVals.drop_front(ArgOuts.size());
@@ -1515,6 +1523,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool IsVAArg = (ArgI >= FirstVAArg);
     const bool IsByVal = Arg.IsByVal;
 
+    const SDValue ParamSymbol =
+        getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
+
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
 
@@ -1525,38 +1536,43 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert(VTs.size() == Offsets.size() && "Size mismatch");
     assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
-    Align ArgAlign;
-    if (IsByVal) {
-      // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
-      // so we don't need to worry whether it's naturally aligned or not.
-      // See TargetLowering::LowerCallTo().
-      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
-                                            InitialAlign, DL);
-      if (IsVAArg)
-        VAOffset = alignTo(VAOffset, ArgAlign);
-    } else {
-      ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
-    }
+    const Align ArgAlign = [&]() {
+      if (IsByVal) {
+        // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
+        // so we don't need to worry whether it's naturally aligned or not.
+        // See TargetLowering::LowerCallTo().
+        const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+        const Align ByValAlign = getFunctionByValParamAlign(
+            CB->getCalledFunction(), ETy, InitialAlign, DL);
+        if (IsVAArg)
+          VAOffset = alignTo(VAOffset, ByValAlign);
+        return ByValAlign;
+      }
+      return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
+    }();
 
     const unsigned TypeSize = DL.getTypeAllocSize(ETy);
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
-    if (IsVAArg) {
-      if (ArgI == FirstVAArg) {
-        VADeclareParam = Chain =
-            DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
-                        {Chain, GetI32(STI.getMaxRequiredAlignment()),
-                         GetI32(ArgI), GetI32(1), InGlue});
+    const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+      if (IsVAArg) {
+        if (ArgI == FirstVAArg) {
+          VADeclareParam = DAG.getNode(
+              NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+              {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
+               GetI32(0), InGlue});
+          return VADeclareParam;
+        }
+        return std::nullopt;
+      }
+      if (IsByVal || shouldPassAsArray(Arg.Ty)) {
+        // declare .param .align <align> .b8 .param<n>[<size>];
+        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+                           {MVT::Other, MVT::Glue},
+                           {Chain, ParamSymbol, GetI32(ArgAlign.value()),
+                            GetI32(TypeSize), InGlue});
       }
-    } else if (PassAsArray) {
-      // declare .param .align <align> .b8 .param<n>[<size>];
-      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
-                          {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
-                           GetI32(TypeSize), InGlue});
-    } else {
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
 
@@ -1568,11 +1584,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
               ? promoteScalarArgumentSize(TypeSize * 8)
               : TypeSize * 8;
 
-      Chain =
-          DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
-                      {Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
+      return DAG.getNode(NVPTXISD::DeclareScalarParam, dl,
+                         {MVT::Other, MVT::Glue},
+                         {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
+    }();
+    if (ArgDeclare) {
+      Chain = ArgDeclare->getValue(0);
+      InGlue = ArgDeclare->getValue(1);
     }
-    InGlue = Chain.getValue(1);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
     // than 32-bits are sign extended or zero extended, depending on
@@ -1594,8 +1613,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       } else {
         StVal = ArgOutVals[I];
 
-        if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
-          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+        auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType());
+        if (PromotedVT != StVal.getValueType()) {
+          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT,
                               StVal);
         }
       }
@@ -1619,12 +1639,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     unsigned J = 0;
     for (const unsigned NumElts : VectorInfo) {
       const int CurOffset = Offsets[J];
-      EVT EltVT = VTs[J];
+      EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
       const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
 
-      if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
-        EltVT = *PromotedVT;
-
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
       if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
@@ -1695,27 +1712,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
-  MaybeAlign RetAlign = std::nullopt;
 
   // Handle Result
   if (!Ins.empty()) {
-    RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
-
-    // Declare
-    //  .param .align N .b8 retval0[<size-in-bytes>], or
-    //  .param .b<size-in-bits> retval0
-    const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
-    if (!shouldPassAsArray(RetTy)) {
-      const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
-      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
-                          {Chain, GetI32(PromotedResultSize), InGlue});
-      InGlue = Chain.getValue(1);
-    } else {
-      Chain = DAG.getNode(
-          NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
-          {Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
-      InGlue = Chain.getValue(1);
-    }
+    const SDValue RetDeclare = [&]() {
+      const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+      const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
+      if (shouldPassAsArray(RetTy)) {
+        const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+                           {MVT::Other, MVT::Glue},
+                           {Chain, RetSymbol, GetI32(RetAlign.value()),
+                            GetI32(ResultSize / 8), InGlue});
+      }
+      const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize);
+      return DAG.getNode(
+          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
+    }();
+    Chain = RetDeclare.getValue(0);
+    InGlue = RetDeclare.getValue(1);
   }
 
   const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
@@ -1760,7 +1776,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
     std::string Proto =
-        getPrototype(DL, RetTy, Args, CLI.Outs, RetAlign,
+        getPrototype(DL, RetTy, Args, CLI.Outs,
                      HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
                      UniqueCallSite);
     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
@@ -1773,11 +1789,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (ConvertToIndirectCall) {
     // Copy the function ptr to a ptx register and use the register to call the
     // function.
-    EVT DestVT = Callee.getValueType();
-    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+    const MVT DestVT = Callee.getValueType().getSimpleVT();
+    MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    unsigned DestReg =
-        RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+    Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
     auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
     Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
   }
@@ -1810,7 +1825,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
-    assert(RetAlign && "RetAlign is guaranteed to be set");
+    const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1818,17 +1833,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, *RetAlign);
+    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
     unsigned I = 0;
     for (const unsigned VectorizedSize : VectorInfo) {
-      EVT TheLoadType = VTs[I];
+      EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]);
       EVT EltType = Ins[I].VT;
-      const Align EltAlign = commonAlignment(*RetAlign, Offsets[I]);
+      const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
 
-      if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
-        TheLoadType = *PromotedVT;
-        EltType = *PromotedVT;
-      }
+      if (TheLoadType != VTs[I])
+        EltType = TheLoadType;
 
       if (ExtendIntegerRetVal) {
         TheLoadType = MVT::i32;
@@ -1898,13 +1911,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       continue;
     }
 
-    SDValue Ret = DAG.getNode(
-        NVPTXISD::ProxyReg, dl,
-        {ProxyRegOps[I].getSimpleValueType(), MVT::Other, MVT::Glue},
-        {Chain, ProxyRegOps[I], InGlue});
-
-    Chain = Ret.getValue(1);
-    InGlue = Ret.getValue(2);
+    SDValue Ret =
+        DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(),
+                    {Chain, ProxyRegOps[I]});
 
     const EVT ExpectedVT = Ins[I].VT;
     if (!Ret.getValueType().bitsEq(ExpectedVT)) {
@@ -1914,14 +1923,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   for (SDValue &T : TempProxyRegOps) {
-    SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl,
-                               {T.getSimpleValueType(), MVT::Other, MVT::Glue},
-                               {Chain, T.getOperand(0), InGlue});
+    SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(),
+                               {Chain, T.getOperand(0)});
     DAG.ReplaceAllUsesWith(T, Repl);
     DAG.RemoveDeadNode(T.getNode());
-
-    Chain = Repl.getValue(1);
-    InGlue = Repl.getValue(2);
   }
 
   // set isTailCall to false for now, until we figure out how to express
@@ -3292,11 +3297,19 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
 // passing variable arguments.
-SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
-                                            EVT v) const {
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
+                                            EVT T) const {
   StringRef SavedStr = nvTM->getStrPool().save(
-      getParamName(&DAG.getMachineFunction().getFunction(), idx));
-  return DAG.getExternalSymbol(SavedStr.data(), v);
+      getParamName(&DAG.getMachineFunction().getFunction(), I));
+  return DAG.getExternalSymbol(SavedStr.data(), T);
+}
+
+// Returns the external symbol "param<I>" used to name the I-th .param declared
+// at a call site, as opposed to the current function's own parameters above.
+SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
+                                                EVT T) const {
+  const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
+  return DAG.getExternalSymbol(SavedStr.data(), T);
 }
 
 SDValue NVPTXTargetLowering::LowerFormalArguments(
@@ -3393,8 +3404,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         const unsigned PackingAmt =
             LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
 
-        const EVT VecVT = EVT::getVectorVT(
-            F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
+        const EVT VecVT =
+            NumElts == 1
+                ? LoadVT
+                : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
+                                   NumElts * PackingAmt);
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3408,22 +3422,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         if (P....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146411


More information about the llvm-commits mailing list