[llvm] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering call lowering 2/n (PR #137666)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 09:40:30 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>



---

Patch is 60.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137666.diff


3 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+418-547) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5-5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+5-28) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c41741ed10232..b21635f7caf04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -343,33 +343,35 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
 /// and promote them to a larger size if they're not.
 ///
 /// The promoted type is placed in \p PromoteVT if the function returns true.
-static bool PromoteScalarIntegerPTX(const EVT &VT, MVT *PromotedVT) {
+static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
   if (VT.isScalarInteger()) {
+    MVT PromotedVT;
     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
     default:
       llvm_unreachable(
           "Promotion is not suitable for scalars of size larger than 64-bits");
     case 1:
-      *PromotedVT = MVT::i1;
+      PromotedVT = MVT::i1;
       break;
     case 2:
     case 4:
     case 8:
-      *PromotedVT = MVT::i8;
+      PromotedVT = MVT::i8;
       break;
     case 16:
-      *PromotedVT = MVT::i16;
+      PromotedVT = MVT::i16;
       break;
     case 32:
-      *PromotedVT = MVT::i32;
+      PromotedVT = MVT::i32;
       break;
     case 64:
-      *PromotedVT = MVT::i64;
+      PromotedVT = MVT::i64;
       break;
     }
-    return EVT(*PromotedVT) != VT;
+    if (VT != PromotedVT)
+      return PromotedVT;
   }
-  return false;
+  return std::nullopt;
 }
 
 // Check whether we can merge loads/stores of some of the pieces of a
@@ -426,16 +428,6 @@ static unsigned CanMergeParamLoadStoresStartingAt(
   return NumElts;
 }
 
-// Flags for tracking per-element vectorization state of loads/stores
-// of a flattened function parameter or return value.
-enum ParamVectorizationFlags {
-  PVF_INNER = 0x0, // Middle elements of a vector.
-  PVF_FIRST = 0x1, // First element of the vector.
-  PVF_LAST = 0x2,  // Last element of the vector.
-  // Scalar is effectively a 1-element vector.
-  PVF_SCALAR = PVF_FIRST | PVF_LAST
-};
-
 // Computes whether and how we can vectorize the loads/stores of a
 // flattened function parameter or return value.
 //
@@ -444,52 +436,39 @@ enum ParamVectorizationFlags {
 // of the same size as ValueVTs indicating how each piece should be
 // loaded/stored (i.e. as a scalar, or as part of a vector
 // load/store).
-static SmallVector<ParamVectorizationFlags, 16>
+static SmallVector<unsigned, 16>
 VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
                      const SmallVectorImpl<uint64_t> &Offsets,
                      Align ParamAlignment, bool IsVAArg = false) {
   // Set vector size to match ValueVTs and mark all elements as
   // scalars by default.
-  SmallVector<ParamVectorizationFlags, 16> VectorInfo;
-  VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
+  SmallVector<unsigned, 16> VectorInfo;
 
-  if (IsVAArg)
+  if (IsVAArg) {
+    VectorInfo.assign(ValueVTs.size(), 1);
     return VectorInfo;
+  }
 
-  // Check what we can vectorize using 128/64/32-bit accesses.
-  for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
-    // Skip elements we've already processed.
-    assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
-    for (unsigned AccessSize : {16, 8, 4, 2}) {
-      unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+  const auto GetNumElts = [&](unsigned I) -> unsigned {
+    for (const unsigned AccessSize : {16, 8, 4, 2}) {
+      const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
           I, AccessSize, ValueVTs, Offsets, ParamAlignment);
-      // Mark vectorized elements.
-      switch (NumElts) {
-      default:
-        llvm_unreachable("Unexpected return value");
-      case 1:
-        // Can't vectorize using this size, try next smaller size.
-        continue;
-      case 2:
-        assert(I + 1 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_LAST;
-        I += 1;
-        break;
-      case 4:
-        assert(I + 3 < E && "Not enough elements.");
-        VectorInfo[I] = PVF_FIRST;
-        VectorInfo[I + 1] = PVF_INNER;
-        VectorInfo[I + 2] = PVF_INNER;
-        VectorInfo[I + 3] = PVF_LAST;
-        I += 3;
-        break;
-      }
-      // Break out of the inner loop because we've already succeeded
-      // using largest possible AccessSize.
-      break;
+      assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
+             "Unexpected vectorization size");
+      if (NumElts != 1)
+        return NumElts;
     }
+    return 1;
+  };
+
+  // Check what we can vectorize using 128/64/32-bit accesses.
+  for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
+    const unsigned NumElts = GetNumElts(I);
+    VectorInfo.push_back(NumElts);
+    I += NumElts;
   }
+  assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
+         ValueVTs.size());
   return VectorInfo;
 }
 
@@ -1165,21 +1144,24 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
 
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
-    std::optional<std::pair<unsigned, const APInt &>> VAInfo,
-    const CallBase &CB, unsigned UniqueCallSite) const {
+    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+    std::optional<std::pair<unsigned, unsigned>> VAInfo, const CallBase &CB,
+    unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
 
   std::string Prototype;
   raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
-  if (retTy->getTypeID() == Type::VoidTyID) {
+  if (retTy->isVoidTy()) {
     O << "()";
   } else {
     O << "(";
-    if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
-        !shouldPassAsArray(retTy)) {
+    if (shouldPassAsArray(retTy)) {
+      assert(RetAlign && "RetAlign must be set for non-void return types");
+      O << ".param .align " << RetAlign->value() << " .b8 _["
+        << DL.getTypeAllocSize(retTy) << "]";
+    } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
         size = ITy->getBitWidth();
@@ -1196,9 +1178,6 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (shouldPassAsArray(retTy)) {
-      O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
-        << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
       llvm_unreachable("Unknown return type");
     }
@@ -1208,57 +1187,52 @@ std::string NVPTXTargetLowering::getPrototype(
 
   bool first = true;
 
-  unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
-  for (unsigned i = 0, OIdx = 0; i != NumArgs; ++i, ++OIdx) {
-    Type *Ty = Args[i].Ty;
+  const unsigned NumArgs = VAInfo ? VAInfo->first : Args.size();
+  auto AllOuts = ArrayRef(Outs);
+  for (const unsigned I : llvm::seq(NumArgs)) {
+    const auto ArgOuts =
+        AllOuts.take_while([I](auto O) { return O.OrigArgIndex == I; });
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+
+    Type *Ty = Args[I].Ty;
     if (!first) {
       O << ", ";
     }
     first = false;
 
-    if (!Outs[OIdx].Flags.isByVal()) {
+    if (ArgOuts[0].Flags.isByVal()) {
+      // Indirect calls need strict ABI alignment so we disable optimizations by
+      // not providing a function to optimize.
+      Type *ETy = Args[I].IndirectType;
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+      Align ParamByValAlign =
+          getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
+
+      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
+        << ArgOuts[0].Flags.getByValSize() << "]";
+    } else {
       if (shouldPassAsArray(Ty)) {
         Align ParamAlign =
-            getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
-        O << ".param .align " << ParamAlign.value() << " .b8 ";
-        O << "_";
-        O << "[" << DL.getTypeAllocSize(Ty) << "]";
-        // update the index for Outs
-        SmallVector<EVT, 16> vtparts;
-        ComputeValueVTs(*this, DL, Ty, vtparts);
-        if (unsigned len = vtparts.size())
-          OIdx += len - 1;
+            getArgumentAlignment(&CB, Ty, I + AttributeList::FirstArgIndex, DL);
+        O << ".param .align " << ParamAlign.value() << " .b8 _["
+          << DL.getTypeAllocSize(Ty) << "]";
         continue;
       }
       // i8 types in IR will be i16 types in SDAG
-      assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
-              (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
+      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
+              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
              "type mismatch between callee prototype and arguments");
       // scalar type
       unsigned sz = 0;
-      if (isa<IntegerType>(Ty)) {
-        sz = cast<IntegerType>(Ty)->getBitWidth();
-        sz = promoteScalarArgumentSize(sz);
+      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
+        sz = promoteScalarArgumentSize(ITy->getBitWidth());
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
       } else {
         sz = Ty->getPrimitiveSizeInBits();
       }
-      O << ".param .b" << sz << " ";
-      O << "_";
-      continue;
+      O << ".param .b" << sz << " _";
     }
-
-    // Indirect calls need strict ABI alignment so we disable optimizations by
-    // not providing a function to optimize.
-    Type *ETy = Args[i].IndirectType;
-    Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
-    Align ParamByValAlign =
-        getFunctionByValParamAlign(/*F=*/nullptr, ETy, InitialAlign, DL);
-
-    O << ".param .align " << ParamByValAlign.value() << " .b8 ";
-    O << "_";
-    O << "[" << Outs[OIdx].Flags.getByValSize() << "]";
   }
 
   if (VAInfo)
@@ -1441,6 +1415,10 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
   return MachinePointerInfo();
 }
 
+static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
+  return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1451,8 +1429,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
-  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
@@ -1462,6 +1438,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   const CallBase *CB = CLI.CB;
   const DataLayout &DL = DAG.getDataLayout();
 
+  const auto GetI32 = [&](const unsigned I) {
+    return DAG.getConstant(I, dl, MVT::i32);
+  };
+
   // Variadic arguments.
   //
   // Normally, for each argument, we declare a param scalar or a param
@@ -1479,7 +1459,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   // vararg byte array.
 
   SDValue VADeclareParam;                 // vararg byte array
-  unsigned FirstVAArg = CLI.NumFixedArgs; // position of the first variadic
+  const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
@@ -1487,7 +1467,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
   SDValue InGlue = Chain.getValue(1);
 
-  unsigned ParamCount = 0;
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
   //   * if there is an aggregate argument with multiple fields (each field
@@ -1497,77 +1476,78 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   //     individually present in Outs.
   // So a different index should be used for indexing into Outs/OutVals.
   // See similar issue in LowerFormalArguments.
-  unsigned OIdx = 0;
+  auto AllOuts = ArrayRef(CLI.Outs);
+  auto AllOutVals = ArrayRef(CLI.OutVals);
+  assert(AllOuts.size() == AllOutVals.size() &&
+         "Outs and OutVals must be the same size");
   // Declare the .params or .reg need to pass values
   // to the function
-  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
-    EVT VT = Outs[OIdx].VT;
-    Type *Ty = Args[i].Ty;
-    bool IsVAArg = (i >= CLI.NumFixedArgs);
-    bool IsByVal = Outs[OIdx].Flags.isByVal();
+  for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
+    const auto ArgOuts = AllOuts.take_while(
+        [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
+    const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
+    AllOuts = AllOuts.drop_front(ArgOuts.size());
+    AllOutVals = AllOutVals.drop_front(ArgOuts.size());
+
+    const bool IsVAArg = (ArgI >= FirstVAArg);
+    const bool IsByVal = Arg.IsByVal;
 
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
 
-    assert((!IsByVal || Args[i].IndirectType) &&
+    assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
-    Type *ETy = (IsByVal ? Args[i].IndirectType : Ty);
+    Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
     ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
+    assert(VTs.size() == Offsets.size() && "Size mismatch");
+    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     Align ArgAlign;
     if (IsByVal) {
       // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
       // so we don't need to worry whether it's naturally aligned or not.
       // See TargetLowering::LowerCallTo().
-      Align InitialAlign = Outs[OIdx].Flags.getNonZeroByValAlign();
+      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
       ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
                                             InitialAlign, DL);
       if (IsVAArg)
         VAOffset = alignTo(VAOffset, ArgAlign);
     } else {
-      ArgAlign = getArgumentAlignment(CB, Ty, ParamCount + 1, DL);
+      ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }
 
-    unsigned TypeSize =
-        (IsByVal ? Outs[OIdx].Flags.getByValSize() : DL.getTypeAllocSize(Ty));
-    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+           "type size mismatch");
 
-    bool NeedAlign; // Does argument declaration specify alignment?
-    const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
+    const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
     if (IsVAArg) {
-      if (ParamCount == FirstVAArg) {
-        SDValue DeclareParamOps[] = {
-            Chain, DAG.getConstant(STI.getMaxRequiredAlignment(), dl, MVT::i32),
-            DAG.getConstant(ParamCount, dl, MVT::i32),
-            DAG.getConstant(1, dl, MVT::i32), InGlue};
-        VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
-                                             DeclareParamVTs, DeclareParamOps);
+      if (ArgI == FirstVAArg) {
+        VADeclareParam = Chain =
+            DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                        {Chain, GetI32(STI.getMaxRequiredAlignment()),
+                         GetI32(ArgI), GetI32(1), InGlue});
       }
-      NeedAlign = PassAsArray;
     } else if (PassAsArray) {
       // declare .param .align <align> .b8 .param<n>[<size>];
-      SDValue DeclareParamOps[] = {
-          Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
-          DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
-                          DeclareParamOps);
-      NeedAlign = true;
+      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
+                          {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
+                           GetI32(TypeSize), InGlue});
     } else {
+      assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
-      if (VT.isInteger() || VT.isFloatingPoint()) {
-        // PTX ABI requires integral types to be at least 32 bits in
-        // size. FP16 is loaded/stored using i16, so it's handled
-        // here as well.
-        TypeSize = promoteScalarArgumentSize(TypeSize * 8) / 8;
-      }
-      SDValue DeclareScalarParamOps[] = {
-          Chain, DAG.getConstant(ParamCount, dl, MVT::i32),
-          DAG.getConstant(TypeSize * 8, dl, MVT::i32),
-          DAG.getConstant(0, dl, MVT::i32), InGlue};
-      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
-                          DeclareScalarParamOps);
-      NeedAlign = false;
+
+      // PTX ABI requires integral types to be at least 32 bits in
+      // size. FP16 is loaded/stored using i16, so it's handled
+      // here as well.
+      const unsigned PromotedSize =
+          (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
+              ? promoteScalarArgumentSize(TypeSize * 8)
+              : TypeSize * 8;
+
+      Chain = DAG.getNode(
+          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
     }
     InGlue = Chain.getValue(1);
 
@@ -1575,196 +1555,169 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // than 32-bits are sign extended or zero extended, depending on
     // whether they are signed or unsigned types. This case applies
     // only to scalar parameters and not to aggregate values.
-    bool ExtendIntegerParam =
-        Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+    const bool ExtendIntegerParam =
+        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
 
-    auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-    SmallVector<SDValue, 6> StoreOperands;
-    for (const unsigned J : llvm::seq(VTs.size())) {
-      EVT EltVT = VTs[J];
-      const int CurOffset = Offsets[J];
-      MaybeAlign PartAlign;
-      if (NeedAlign)
-        PartAlign = commonAlignment(ArgAlign, CurOffset);
+    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
+                                    const Align PartAlign) {
+      SDValue StVal;
+      if (IsByVal) {
+        SDValue Ptr = ArgOutVals[0];
+        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
+        SDValue SrcAddr =
+            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
 
-      SDValue StVal = OutVals[OIdx];
+        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
+      } else {
+        StVal = ArgOutVals[I];
 
-      MVT PromotedVT;
-      if (PromoteScalarIntegerPTX(EltVT, &PromotedVT)) {
-        EltVT = EVT(PromotedVT);
-      }
-      if (PromoteScalarIntegerPTX(StVal.getValueType(), &PromotedVT)) {
-        llvm::ISD::NodeType Ext =
-            Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-        StVal = DAG.getNode(Ext, dl, PromotedVT, StVal);
+        if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
+          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+                              StVal);
+        }
       }
 
-      if (IsByVal) {
-        auto MPI = refinePtrAS(StVal, DAG, DL, *this);
-        const EVT PtrVT = StVal.getValueType();
-        SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
-                                      DAG.getConstant(CurOffset, dl, PtrVT));
-
-        StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
-      } else if (ExtendIntegerParam) {
+      if (ExtendIntegerParam) {
   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/137666


More information about the llvm-commits mailing list