[llvm] ea698c4 - [NVPTX][NFC] Refactoring and cleanup in NVPTXISelLowering (#137222)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 24 20:01:38 PDT 2025


Author: Alex MacLean
Date: 2025-04-24T20:01:35-07:00
New Revision: ea698c444707c8c3e10cb675003beb686fc94103

URL: https://github.com/llvm/llvm-project/commit/ea698c444707c8c3e10cb675003beb686fc94103
DIFF: https://github.com/llvm/llvm-project/commit/ea698c444707c8c3e10cb675003beb686fc94103.diff

LOG: [NVPTX][NFC] Refactoring and cleanup in NVPTXISelLowering (#137222)

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8dd9bf2876927..c41741ed10232 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -493,13 +493,6 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
   return VectorInfo;
 }
 
-static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
-                            SDValue Value) {
-  if (Value->getValueType(0) == VT)
-    return Value;
-  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
-}
-
 // NVPTXTargetLowering Constructor.
 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                          const NVPTXSubtarget &STI)
@@ -1587,9 +1580,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
     auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg);
     SmallVector<SDValue, 6> StoreOperands;
-    for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
-      EVT EltVT = VTs[j];
-      int CurOffset = Offsets[j];
+    for (const unsigned J : llvm::seq(VTs.size())) {
+      EVT EltVT = VTs[J];
+      const int CurOffset = Offsets[J];
       MaybeAlign PartAlign;
       if (NeedAlign)
         PartAlign = commonAlignment(ArgAlign, CurOffset);
@@ -1629,7 +1622,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
-      if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
+      if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() &&
           PartAlign.value() <
               DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) {
         assert(StoreOperands.empty() && "Unfinished preceeding store.");
@@ -1645,7 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       }
 
       // New store.
-      if (VectorInfo[j] & PVF_FIRST) {
+      if (VectorInfo[J] & PVF_FIRST) {
         assert(StoreOperands.empty() && "Unfinished preceding store.");
         StoreOperands.push_back(Chain);
         StoreOperands.push_back(
@@ -1665,8 +1658,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       // Record the value to store.
       StoreOperands.push_back(StVal);
 
-      if (VectorInfo[j] & PVF_LAST) {
-        unsigned NumElts = StoreOperands.size() - 3;
+      if (VectorInfo[J] & PVF_LAST) {
+        const unsigned NumElts = StoreOperands.size() - 3;
         NVPTXISD::NodeType Op;
         switch (NumElts) {
         case 1:
@@ -2168,7 +2161,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
       ISD::OR, DL, MVT::i16,
       {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
   EVT ToVT = Op->getValueType(0);
-  return MaybeBitcast(DAG, DL, ToVT, AsInt);
+  return DAG.getBitcast(ToVT, AsInt);
 }
 
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
@@ -3367,18 +3360,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
   const Function *F = &MF.getFunction();
-  const AttributeList &PAL = F->getAttributes();
-  const TargetLowering *TLI = STI.getTargetLowering();
 
   SDValue Root = DAG.getRoot();
-  std::vector<SDValue> OutChains;
+  SmallVector<SDValue, 16> OutChains;
 
-  std::vector<Type *> argTypes;
-  std::vector<const Argument *> theArgs;
-  for (const Argument &I : F->args()) {
-    theArgs.push_back(&I);
-    argTypes.push_back(I.getType());
-  }
   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
   // Ins.size() will be larger
   //   * if there is an aggregate argument with multiple fields (each field
@@ -3388,49 +3373,59 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   //     individually present in Ins.
   // So a different index should be used for indexing into Ins.
   // See similar issue in LowerCall.
-  unsigned InsIdx = 0;
 
-  for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) {
-    Type *Ty = argTypes[i];
+  auto AllIns = ArrayRef(Ins);
+  for (const auto &Arg : F->args()) {
+    const auto ArgIns = AllIns.take_while(
+        [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
+    AllIns = AllIns.drop_front(ArgIns.size());
 
-    if (theArgs[i]->use_empty()) {
-      // argument is dead
-      if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
-        SmallVector<EVT, 16> vtparts;
+    Type *Ty = Arg.getType();
 
-        ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
-        if (vtparts.empty())
-          report_fatal_error("Empty parameter types are not supported");
+    if (ArgIns.empty())
+      report_fatal_error("Empty parameter types are not supported");
 
-        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
-             ++parti) {
-          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
-          ++InsIdx;
-        }
-        if (vtparts.size() > 0)
-          --InsIdx;
-        continue;
-      }
-      if (Ty->isVectorTy()) {
-        EVT ObjectVT = getValueType(DL, Ty);
-        unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
-        for (unsigned parti = 0; parti < NumRegs; ++parti) {
-          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
-          ++InsIdx;
-        }
-        if (NumRegs > 0)
-          --InsIdx;
-        continue;
+    if (Arg.use_empty()) {
+      // argument is dead
+      for (const auto &In : ArgIns) {
+        assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
+        InVals.push_back(DAG.getUNDEF(In.VT));
       }
-      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
       continue;
     }
 
+    SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT);
+
     // In the following cases, assign a node order of "i+1"
     // to newly created nodes. The SDNodes for params have to
     // appear in the same order as their order of appearance
     // in the original function. "i+1" holds that order.
-    if (!PAL.hasParamAttr(i, Attribute::ByVal)) {
+    if (Arg.hasByValAttr()) {
+      // Param has ByVal attribute
+      // Return MoveParam(param symbol).
+      // Ideally, the param symbol can be returned directly,
+      // but when SDNode builder decides to use it in a CopyToReg(),
+      // machine instruction fails because TargetExternalSymbol
+      // (not lowered) is target dependent, and CopyToReg assumes
+      // the source is lowered.
+      assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
+      const auto &ByvalIn = ArgIns[0];
+      assert(getValueType(DL, Ty) == ByvalIn.VT &&
+             "Ins type did not match function type");
+      assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
+
+      SDValue P;
+      if (isKernelFunction(*F)) {
+        P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
+        P.getNode()->setIROrder(Arg.getArgNo() + 1);
+      } else {
+        P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
+        P.getNode()->setIROrder(Arg.getArgNo() + 1);
+        P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL,
+                                 ADDRESS_SPACE_GENERIC);
+      }
+      InVals.push_back(P);
+    } else {
       bool aggregateIsPacked = false;
       if (StructType *STy = dyn_cast<StructType>(Ty))
         aggregateIsPacked = STy->isPacked();
@@ -3438,25 +3433,25 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
       ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
-      if (VTs.empty())
-        report_fatal_error("Empty parameter types are not supported");
+      assert(VTs.size() == ArgIns.size() && "Size mismatch");
+      assert(VTs.size() == Offsets.size() && "Size mismatch");
 
       Align ArgAlign = getFunctionArgumentAlignment(
-          F, Ty, i + AttributeList::FirstArgIndex, DL);
+          F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
       auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+      assert(VectorInfo.size() == VTs.size() && "Size mismatch");
 
-      SDValue Arg = getParamSymbol(DAG, i, PtrVT);
       int VecIdx = -1; // Index of the first element of the current vector.
-      for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
-        if (VectorInfo[parti] & PVF_FIRST) {
+      for (const unsigned PartI : llvm::seq(VTs.size())) {
+        if (VectorInfo[PartI] & PVF_FIRST) {
           assert(VecIdx == -1 && "Orphaned vector.");
-          VecIdx = parti;
+          VecIdx = PartI;
         }
 
         // That's the last element of this store op.
-        if (VectorInfo[parti] & PVF_LAST) {
-          unsigned NumElts = parti - VecIdx + 1;
-          EVT EltVT = VTs[parti];
+        if (VectorInfo[PartI] & PVF_LAST) {
+          const unsigned NumElts = PartI - VecIdx + 1;
+          EVT EltVT = VTs[PartI];
           // i1 is loaded/stored as i8.
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
@@ -3469,10 +3464,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
           EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
           SDValue VecAddr =
-              DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
+              DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol,
                           DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
-          Value *srcValue = Constant::getNullValue(
-              PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM));
 
           const MaybeAlign PartAlign = [&]() -> MaybeAlign {
             if (aggregateIsPacked)
@@ -3481,23 +3474,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
               return std::nullopt;
             Align PartAlign =
                 DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext()));
-            return commonAlignment(PartAlign, Offsets[parti]);
+            return commonAlignment(PartAlign, Offsets[PartI]);
           }();
-          SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
-                                  MachinePointerInfo(srcValue), PartAlign,
-                                  MachineMemOperand::MODereferenceable |
-                                      MachineMemOperand::MOInvariant);
+          SDValue P =
+              DAG.getLoad(VecVT, dl, Root, VecAddr,
+                          MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
+                          MachineMemOperand::MODereferenceable |
+                              MachineMemOperand::MOInvariant);
           if (P.getNode())
-            P.getNode()->setIROrder(i + 1);
-          for (unsigned j = 0; j < NumElts; ++j) {
+            P.getNode()->setIROrder(Arg.getArgNo() + 1);
+          for (const unsigned J : llvm::seq(NumElts)) {
             SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                      DAG.getIntPtrConstant(j, dl));
+                                      DAG.getIntPtrConstant(J, dl));
             // We've loaded i1 as an i8 and now must truncate it back to i1
             if (EltVT == MVT::i1)
               Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
             // v2f16 was loaded as an i32. Now we must bitcast it back.
-            else if (EltVT != LoadVT)
-              Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
+            Elt = DAG.getBitcast(EltVT, Elt);
 
             // If a promoted integer type is used, truncate down to the original
             MVT PromotedVT;
@@ -3507,12 +3500,12 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
             // Extend the element if necessary (e.g. an i8 is loaded
             // into an i16 register)
-            if (Ins[InsIdx].VT.isInteger() &&
-                Ins[InsIdx].VT.getFixedSizeInBits() >
-                    LoadVT.getFixedSizeInBits()) {
-              unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
-                                                           : ISD::ZERO_EXTEND;
-              Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
+            if (ArgIns[PartI].VT.getFixedSizeInBits() !=
+                LoadVT.getFixedSizeInBits()) {
+              assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() &&
+                     "Non-integer argument type size mismatch");
+              Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl,
+                                      ArgIns[PartI].VT);
             }
             InVals.push_back(Elt);
           }
@@ -3520,40 +3513,12 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           // Reset vector tracking state.
           VecIdx = -1;
         }
-        ++InsIdx;
       }
-      if (VTs.size() > 0)
-        --InsIdx;
-      continue;
-    }
-
-    // Param has ByVal attribute
-    // Return MoveParam(param symbol).
-    // Ideally, the param symbol can be returned directly,
-    // but when SDNode builder decides to use it in a CopyToReg(),
-    // machine instruction fails because TargetExternalSymbol
-    // (not lowered) is target dependent, and CopyToReg assumes
-    // the source is lowered.
-    EVT ObjectVT = getValueType(DL, Ty);
-    assert(ObjectVT == Ins[InsIdx].VT &&
-           "Ins type did not match function type");
-    SDValue Arg = getParamSymbol(DAG, i, PtrVT);
-
-    SDValue P;
-    if (isKernelFunction(*F)) {
-      P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
-      P.getNode()->setIROrder(i + 1);
-    } else {
-      P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
-      P.getNode()->setIROrder(i + 1);
-      P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
-                               ADDRESS_SPACE_GENERIC);
     }
-    InVals.push_back(P);
   }
 
   if (!OutChains.empty())
-    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains));
+    DAG.setRoot(DAG.getTokenFactor(dl, OutChains));
 
   return Chain;
 }
@@ -5784,7 +5749,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
 
   // Bitcast to i16 and unpack elements into a vector
   SDLoc DL(Node);
-  SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
+  SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0));
   SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
   SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
   SDValue Vec1 =


        


More information about the llvm-commits mailing list