[llvm] [NVPTX] Fix v2i8 call lowering, use generic ld/st nodes for call params (PR #146930)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 3 10:29:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>



---

Patch is 241.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146930.diff


37 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (-273) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+227-364) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2-8) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+3-130) 
- (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+3-3) 
- (modified) llvm/test/CodeGen/NVPTX/byval-const-global.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll (+5-5) 
- (modified) llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll (+8-8) 
- (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/compare-int.ll (+501-120) 
- (modified) llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll (+162-10) 
- (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/fma.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/i128-param.ll (+10-10) 
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+93-28) 
- (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/idioms.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/indirect_byval.ll (+10-10) 
- (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+14-19) 
- (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+4-6) 
- (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+7-8) 
- (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+124-134) 
- (modified) llvm/test/CodeGen/NVPTX/param-overalign.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/param-vectorize-device.ll (+14-14) 
- (modified) llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+146-75) 
- (modified) llvm/test/CodeGen/NVPTX/store-undef.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/tex-read-cuda.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+242-352) 
- (modified) llvm/test/CodeGen/NVPTX/vaargs.ll (+8-8) 
- (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+15-17) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..fdd2671a5289e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryStoreVector(N))
       return;
     break;
-  case NVPTXISD::LoadParam:
-  case NVPTXISD::LoadParamV2:
-  case NVPTXISD::LoadParamV4:
-    if (tryLoadParam(N))
-      return;
-    break;
-  case NVPTXISD::StoreParam:
-  case NVPTXISD::StoreParamV2:
-  case NVPTXISD::StoreParamV4:
-    if (tryStoreParam(N))
-      return;
-    break;
   case ISD::INTRINSIC_W_CHAIN:
     if (tryIntrinsicChain(N))
       return;
@@ -1429,267 +1417,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
-  SDValue Chain = Node->getOperand(0);
-  SDValue Offset = Node->getOperand(2);
-  SDValue Glue = Node->getOperand(3);
-  SDLoc DL(Node);
-  MemSDNode *Mem = cast<MemSDNode>(Node);
-
-  unsigned VecSize;
-  switch (Node->getOpcode()) {
-  default:
-    return false;
-  case NVPTXISD::LoadParam:
-    VecSize = 1;
-    break;
-  case NVPTXISD::LoadParamV2:
-    VecSize = 2;
-    break;
-  case NVPTXISD::LoadParamV4:
-    VecSize = 4;
-    break;
-  }
-
-  EVT EltVT = Node->getValueType(0);
-  EVT MemVT = Mem->getMemoryVT();
-
-  std::optional<unsigned> Opcode;
-
-  switch (VecSize) {
-  default:
-    return false;
-  case 1:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
-                             NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
-    break;
-  case 2:
-    Opcode =
-        pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
-                        NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
-                        NVPTX::LoadParamMemV2I64);
-    break;
-  case 4:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
-                             NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
-    break;
-  }
-  if (!Opcode)
-    return false;
-
-  SDVTList VTs;
-  if (VecSize == 1) {
-    VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
-  } else if (VecSize == 2) {
-    VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
-  } else {
-    EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
-    VTs = CurDAG->getVTList(EVTs);
-  }
-
-  unsigned OffsetVal = Offset->getAsZExtVal();
-
-  SmallVector<SDValue, 2> Ops(
-      {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
-  return true;
-}
-
-// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
-#define getOpcV2H(ty, opKind0, opKind1)                                        \
-  NVPTX::StoreParamV2##ty##_##opKind0##opKind1
-
-#define getOpcV2H1(ty, opKind0, isImm1)                                        \
-  (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
-
-#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
-
-#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
-  NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
-
-#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
-  (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                       \
-           : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
-
-#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
-  (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                       \
-           : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
-
-#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
-  (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                        \
-           : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
-
-#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3])                 \
-             : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
-
-#define getOpcodeForVectorStParam(n, ty, isimm)                                \
-  (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                            \
-           : getOpcodeForVectorStParamV4(ty, isimm)
-
-static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
-                                           unsigned NumElts,
-                                           MVT::SimpleValueType MemTy,
-                                           SelectionDAG *CurDAG, SDLoc DL) {
-  // Determine which inputs are registers and immediates make new operators
-  // with constant values
-  SmallVector<bool, 4> IsImm(NumElts, false);
-  for (unsigned i = 0; i < NumElts; i++) {
-    IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
-    if (IsImm[i]) {
-      SDValue Imm = Ops[i];
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[i] = Imm;
-    }
-  }
-
-  // Get opcode for MemTy, size, and register/immediate operand ordering
-  switch (MemTy) {
-  case MVT::i8:
-    return getOpcodeForVectorStParam(NumElts, I8, IsImm);
-  case MVT::i16:
-    return getOpcodeForVectorStParam(NumElts, I16, IsImm);
-  case MVT::i32:
-    return getOpcodeForVectorStParam(NumElts, I32, IsImm);
-  case MVT::i64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(I64, IsImm);
-  case MVT::f32:
-    return getOpcodeForVectorStParam(NumElts, F32, IsImm);
-  case MVT::f64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(F64, IsImm);
-
-  // These cases don't support immediates, just use the all register version
-  // and generate moves.
-  case MVT::i1:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
-                          : NVPTX::StoreParamV4I8_rrrr;
-  case MVT::f16:
-  case MVT::bf16:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
-                          : NVPTX::StoreParamV4I16_rrrr;
-  case MVT::v2f16:
-  case MVT::v2bf16:
-  case MVT::v2i16:
-  case MVT::v4i8:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
-                          : NVPTX::StoreParamV4I32_rrrr;
-  default:
-    llvm_unreachable("Cannot select st.param for unknown MemTy");
-  }
-}
-
-bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  SDValue Param = N->getOperand(1);
-  unsigned ParamVal = Param->getAsZExtVal();
-  SDValue Offset = N->getOperand(2);
-  unsigned OffsetVal = Offset->getAsZExtVal();
-  MemSDNode *Mem = cast<MemSDNode>(N);
-  SDValue Glue = N->getOperand(N->getNumOperands() - 1);
-
-  // How many elements do we have?
-  unsigned NumElts;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case NVPTXISD::StoreParam:
-    NumElts = 1;
-    break;
-  case NVPTXISD::StoreParamV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::StoreParamV4:
-    NumElts = 4;
-    break;
-  }
-
-  // Build vector of operands
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0; i < NumElts; ++i)
-    Ops.push_back(N->getOperand(i + 3));
-  Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
-              CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  // Determine target opcode
-  // If we have an i1, use an 8-bit store. The lowering code in
-  // NVPTXISelLowering will have already emitted an upcast.
-  std::optional<unsigned> Opcode;
-  switch (NumElts) {
-  default:
-    llvm_unreachable("Unexpected NumElts");
-  case 1: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    SDValue Imm = Ops[0];
-    if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
-        (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
-      // Convert immediate to target constant
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[0] = Imm;
-      // Use immediate version of store param
-      Opcode =
-          pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
-                          NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
-    } else
-      Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                               NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
-                               NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
-    if (Opcode == NVPTX::StoreParamI8_r) {
-      // Fine tune the opcode depending on the size of the operand.
-      // This helps to avoid creating redundant COPY instructions in
-      // InstrEmitter::AddRegisterOperand().
-      switch (Ops[0].getSimpleValueType().SimpleTy) {
-      default:
-        break;
-      case MVT::i32:
-        Opcode = NVPTX::StoreParamI8TruncI32_r;
-        break;
-      case MVT::i64:
-        Opcode = NVPTX::StoreParamI8TruncI64_r;
-        break;
-      }
-    }
-    break;
-  }
-  case 2:
-  case 4: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
-    break;
-  }
-  }
-
-  SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
-  SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
-  CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
-
-  ReplaceNode(N, Ret);
-  return true;
-}
-
 /// SelectBFE - Look for instruction sequences that can be made more efficient
 /// by using the 'bfe' (bit-field extract) PTX instruction
 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..19b569d638b0c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryLDG(MemSDNode *N);
   bool tryStore(SDNode *N);
   bool tryStoreVector(SDNode *N);
-  bool tryLoadParam(SDNode *N);
-  bool tryStoreParam(SDNode *N);
   bool tryFence(SDNode *N);
   void SelectAddrSpaceCast(SDNode *N);
   bool tryBFE(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb0aeb493ed48..3cda7df55ad58 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1049,12 +1049,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::DeclareArrayParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
     MAKE_CASE(NVPTXISD::CALL)
-    MAKE_CASE(NVPTXISD::LoadParam)
-    MAKE_CASE(NVPTXISD::LoadParamV2)
-    MAKE_CASE(NVPTXISD::LoadParamV4)
-    MAKE_CASE(NVPTXISD::StoreParam)
-    MAKE_CASE(NVPTXISD::StoreParamV2)
-    MAKE_CASE(NVPTXISD::StoreParamV4)
     MAKE_CASE(NVPTXISD::MoveParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1293,105 +1287,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
   return DL.getABITypeAlign(Ty);
 }
 
-static bool adjustElementType(EVT &ElementType) {
-  switch (ElementType.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-  case MVT::bf16:
-    ElementType = MVT::i16;
-    return true;
-  case MVT::f32:
-  case MVT::v2f16:
-  case MVT::v2bf16:
-    ElementType = MVT::i32;
-    return true;
-  case MVT::f64:
-    ElementType = MVT::i64;
-    return true;
-  }
-}
-
-// Use byte-store when the param address of the argument value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-//
-// This is called in LowerCall() when passing the param values.
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
-                                        uint64_t Offset, EVT ElementType,
-                                        SDValue StVal, SDValue &InGlue,
-                                        unsigned ArgID, const SDLoc &dl) {
-  // Bit logic only works on integer types
-  if (adjustElementType(ElementType))
-    StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
-
-  // Store each byte
-  SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    // Shift the byte to the last byte position
-    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
-                                   DAG.getConstant(i * 8, dl, MVT::i32));
-    SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
-                               DAG.getConstant(Offset + i, dl, MVT::i32),
-                               ShiftVal, InGlue};
-    // Trunc store only the last byte by using
-    //     st.param.b8
-    // The register type can be larger than b8.
-    Chain = DAG.getMemIntrinsicNode(
-        NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
-        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
-    InGlue = Chain.getValue(1);
-  }
-  return Chain;
-}
-
-// Use byte-load when the param adress of the returned value is unaligned.
-// This may happen when the returned value is a field of a packed structure.
-static SDValue
-LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
-                           EVT ElementType, SDValue &InGlue,
-                           SmallVectorImpl<SDValue> &TempProxyRegOps,
-                           const SDLoc &dl) {
-  // Bit logic only works on integer types
-  EVT MergedType = ElementType;
-  adjustElementType(MergedType);
-
-  // Load each byte and construct the whole value. Initial value to 0
-  SDValue RetVal = DAG.getConstant(0, dl, MergedType);
-  // LoadParamMemI8 loads into i16 register only
-  SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
-                              DAG.getConstant(Offset + i, dl, MVT::i32),
-                              InGlue};
-    // This will be selected to LoadParamMemI8
-    SDValue LdVal =
-        DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
-                                MVT::i8, MachinePointerInfo(), Align(1));
-    SDValue TmpLdVal = LdVal.getValue(0);
-    Chain = LdVal.getValue(1);
-    InGlue = LdVal.getValue(2);
-
-    TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
-                           TmpLdVal.getSimpleValueType(), TmpLdVal);
-    TempProxyRegOps.push_back(TmpLdVal);
-
-    SDValue CMask = DAG.getConstant(255, dl, MergedType);
-    SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
-    // Need to extend the i16 register to the whole width.
-    TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
-    // Mask off the high bits. Leave only the lower 8bits.
-    // Do this because we are using loadparam.b8.
-    TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
-    // Shift and merge
-    TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
-    RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
-  }
-  if (ElementType != MergedType)
-    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
-  return RetVal;
-}
-
 static bool shouldConvertToIndirectCall(const CallBase *CB,
                                         const GlobalAddressSDNode *Func) {
   if (!Func)
@@ -1458,10 +1353,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
-  SDValue Chain = CLI.Chain;
+  const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Callee = CLI.Callee;
-  bool &isTailCall = CLI.IsTailCall;
   ArgListTy &Args = CLI.getArgs();
   Type *RetTy = CLI.RetTy;
   const CallBase *CB = CLI.CB;
@@ -1492,9 +1385,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
-  SDValue TempChain = Chain;
-  Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
-  SDValue InGlue = Chain.getValue(1);
+  const SDValue CallChain = CLI.Chain;
+  const SDValue StartChain =
+      DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
+  SDValue DeclareGlue = StartChain.getValue(1);
+
+  SmallVector<SDValue, 16> CallPrereqs{StartChain};
+
+  const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
+    // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
+    // loaded/stored using i16, so it's handled here as well.
+    const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
+    SDValue Declare =
+        DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+                    {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
+
+  const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
+                                     unsigned Size) {
+    SDValue Declare = DAG.getNode(
+        NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+        {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
 
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
@@ -1555,43 +1473,23 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+    const SDValue ArgDeclare = [&]() {
       if (IsVAArg) {
-        if (ArgI == FirstVAArg) {
-          VADeclareParam = DAG.getNode(
-              NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
-              {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
-               GetI32(0), InGlue});
-          return VADeclareParam;
-        }
-        return std::nullopt;
-      }
-      if (IsByVal || shouldPassAsArray(Arg.Ty)) {
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
-                           {MVT::Other, MVT::Glue},
-                           {Chain, ParamSymbol, GetI32(ArgAlign.value()),
-                            GetI32(TypeSize), InGlue});
+        if (ArgI == FirstVAArg)
+          VADeclareParam = DeclareArrayParam(
+              ParamSymbol, Align(STI.getMaxRequiredAlignment()), 0);
+        return VADeclareParam;
       }
+
+      if (IsByVal || shouldPassAsArray(Arg.Ty))
+        return DeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
-      // declare .param .b<size> .param<n>;
-
-      // PTX ABI requ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146930


More information about the llvm-commits mailing list