[llvm] [NVPTX] Vectorize loads when lowering of byval calls, misc. cleanup (PR #151070)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 29 16:46:11 PDT 2025


================
@@ -1476,139 +1524,144 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const SDValue ParamSymbol =
         getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
 
-    SmallVector<EVT, 16> VTs;
-    SmallVector<uint64_t, 16> Offsets;
-
     assert((!IsByVal || Arg.IndirectType) &&
            "byval arg must have indirect type");
     Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
-    ComputePTXValueVTs(*this, DL, ETy, VTs, &Offsets, IsByVal ? 0 : VAOffset);
-    assert(VTs.size() == Offsets.size() && "Size mismatch");
-    assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
     const Align ArgAlign = [&]() {
       if (IsByVal) {
         // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
         // so we don't need to worry whether it's naturally aligned or not.
         // See TargetLowering::LowerCallTo().
         const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-        const Align ByValAlign = getFunctionByValParamAlign(
-            CB->getCalledFunction(), ETy, InitialAlign, DL);
-        if (IsVAArg)
-          VAOffset = alignTo(VAOffset, ByValAlign);
-        return ByValAlign;
+        return getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
+                                          InitialAlign, DL);
       }
       return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
     }();
 
-    const unsigned TypeSize = DL.getTypeAllocSize(ETy);
-    assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
+    const unsigned TySize = DL.getTypeAllocSize(ETy);
+    assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
     const SDValue ArgDeclare = [&]() {
       if (IsVAArg)
         return VADeclareParam;
 
       if (IsByVal || shouldPassAsArray(Arg.Ty))
-        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+        return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
 
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
              "Only int and float types are supported as non-array arguments");
 
-      return MakeDeclareScalarParam(ParamSymbol, TypeSize);
+      return MakeDeclareScalarParam(ParamSymbol, TySize);
     }();
 
-    // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
-    // than 32-bits are sign extended or zero extended, depending on
-    // whether they are signed or unsigned types. This case applies
-    // only to scalar parameters and not to aggregate values.
-    const bool ExtendIntegerParam =
-        Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+    if (IsByVal) {
+      assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
+      SDValue SrcPtr = ArgOutVals[0];
+      const auto PointerInfo = refinePtrAS(SrcPtr, DAG, DL, *this);
+      const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
 
-    const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
-                                    const MaybeAlign PartAlign) {
-      if (IsByVal) {
-        SDValue Ptr = ArgOutVals[0];
-        auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
-        SDValue SrcAddr =
-            DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
-
-        return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
+      if (IsVAArg)
+        VAOffset = alignTo(VAOffset, ArgAlign);
+
+      SmallVector<EVT, 4> ValueVTs, MemVTs;
+      SmallVector<TypeSize, 4> Offsets;
+      ComputeValueVTs(*this, DL, ETy, ValueVTs, &MemVTs, &Offsets);
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(MemVTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        EVT LoadVT = getVectorizedVT(MemVTs[J], NumElts, Ctx);
+        Align SrcAlign = commonAlignment(BaseSrcAlign, Offsets[J]);
+        SDValue SrcAddr = DAG.getObjectPtrOffset(dl, SrcPtr, Offsets[J]);
+        SDValue SrcLoad =
+            DAG.getLoad(LoadVT, dl, CallChain, SrcAddr, PointerInfo, SrcAlign);
+
+        TypeSize ParamOffset = Offsets[J].getWithIncrement(VAOffset);
+        Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
+        SDValue ParamAddr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
+        SDValue StoreParam =
+            DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
+                         MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+        CallPrereqs.push_back(StoreParam);
+
+        J += NumElts;
       }
-      SDValue StVal = ArgOutVals[I];
-      assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
-                 StVal.getValueType() &&
-             "OutVal type should always be legal");
-
-      const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
-      const EVT StoreVT =
-          ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
-
-      return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
-    };
-
-    const auto VectorInfo =
-        VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
-
-    unsigned J = 0;
-    for (const unsigned NumElts : VectorInfo) {
-      const int CurOffset = Offsets[J];
-      const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
-
-      if (IsVAArg && !IsByVal)
-        // Align each part of the variadic argument to their type.
-        VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
-
-      assert((IsVAArg || VAOffset == 0) &&
-             "VAOffset must be 0 for non-VA args");
+      if (IsVAArg)
+        VAOffset += TySize;
+    } else {
+      SmallVector<EVT, 16> VTs;
+      SmallVector<uint64_t, 16> Offsets;
+      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      assert(VTs.size() == Offsets.size() && "Size mismatch");
+      assert(VTs.size() == ArgOuts.size() && "Size mismatch");
+
+      // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+      // than 32-bits are sign extended or zero extended, depending on
+      // whether they are signed or unsigned types. This case applies
+      // only to scalar parameters and not to aggregate values.
+      const bool ExtendIntegerParam =
+          Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
+
+      const auto GetStoredValue = [&](const unsigned I) {
+        SDValue StVal = ArgOutVals[I];
+        assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+                   StVal.getValueType() &&
+               "OutVal type should always be legal");
+
+        const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+        const EVT StoreVT =
+            ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+        return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
+      };
+
+      unsigned J = 0;
+      const auto VI = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
+      for (const unsigned NumElts : VI) {
+        const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
+
+        unsigned Offset;
+        if (IsVAArg) {
+          // TODO: We may need to support vector types that can be passed
+          // as scalars in variadic arguments.
+          assert(NumElts == 1 &&
+                 "Vectorization should be disabled for vaargs.");
+
+          // Align each part of the variadic argument to their type.
+          VAOffset = alignTo(VAOffset, DAG.getEVTAlign(EltVT));
+          Offset = VAOffset;
+
+          const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
+          VAOffset += DL.getTypeAllocSize(TheStoreType.getTypeForEVT(Ctx));
+        } else {
+          assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
+          Offset = Offsets[J];
+        }
 
-      const unsigned Offset =
-          (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
-      SDValue Ptr =
-          DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
+        SDValue Ptr =
+            DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
 
-      const MaybeAlign CurrentAlign = ExtendIntegerParam
-                                          ? MaybeAlign(std::nullopt)
-                                          : commonAlignment(ArgAlign, Offset);
+        const MaybeAlign CurrentAlign = ExtendIntegerParam
+                                            ? MaybeAlign(std::nullopt)
+                                            : commonAlignment(ArgAlign, Offset);
 
-      SDValue Val;
-      if (NumElts == 1) {
-        Val = GetStoredValue(J, EltVT, CurrentAlign);
-      } else {
-        SmallVector<SDValue, 8> StoreVals;
-        for (const unsigned K : llvm::seq(NumElts)) {
-          SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
-          if (ValJ.getValueType().isVector())
-            DAG.ExtractVectorElements(ValJ, StoreVals);
-          else
-            StoreVals.push_back(ValJ);
-        }
+        SDValue Val = getBuildVectorizedValue(
+            [&](unsigned K) { return GetStoredValue(J + K); }, NumElts, dl,
----------------
Artem-B wrote:

Nit: I'd use `[=]` here to indicate that we don't intend to modify the captured values.

https://github.com/llvm/llvm-project/pull/151070


More information about the llvm-commits mailing list