[llvm] [NVPTX] Cleanup ld/st lowering (PR #143936)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 12 10:14:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>



---

Patch is 42.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143936.diff


9 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+174-276) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+2-1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (-4) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+39-55) 
- (modified) llvm/test/CodeGen/NVPTX/bug26185-2.ll (+19-3) 
- (modified) llvm/test/CodeGen/NVPTX/bug26185.ll (+61-12) 
- (modified) llvm/test/CodeGen/NVPTX/i1-ext-load.ll (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/ldu-ldg.ll (+2-6) 
- (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+8-11) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 32223bf3d601e..5ffb0dccca4ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -136,7 +136,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     break;
   case NVPTXISD::LDUV2:
   case NVPTXISD::LDUV4:
-    if (tryLDGLDU(N))
+    if (tryLDU(N))
       return;
     break;
   case NVPTXISD::StoreV2:
@@ -324,7 +324,7 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
   case Intrinsic::nvvm_ldu_global_f:
   case Intrinsic::nvvm_ldu_global_i:
   case Intrinsic::nvvm_ldu_global_p:
-    return tryLDGLDU(N);
+    return tryLDU(N);
 
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
@@ -1048,35 +1048,28 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   assert(LD->readMem() && "Expected load");
 
   // do not support pre/post inc/dec
-  LoadSDNode *PlainLoad = dyn_cast<LoadSDNode>(N);
+  const LoadSDNode *PlainLoad = dyn_cast<LoadSDNode>(LD);
   if (PlainLoad && PlainLoad->isIndexed())
     return false;
 
-  EVT LoadedVT = LD->getMemoryVT();
-  if (!LoadedVT.isSimple())
+  const EVT LoadedEVT = LD->getMemoryVT();
+  if (!LoadedEVT.isSimple())
     return false;
+  const MVT LoadedVT = LoadedEVT.getSimpleVT();
 
   // Address Space Setting
   const unsigned CodeAddrSpace = getCodeAddrSpace(LD);
   if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace))
-    return tryLDGLDU(N);
+    return tryLDG(LD);
 
-  SDLoc DL(N);
+  SDLoc DL(LD);
   SDValue Chain = N->getOperand(0);
-  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
+  const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
 
-  // Type Setting: fromType + fromTypeWidth
-  //
-  // Sign   : ISD::SEXTLOAD
-  // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
-  //          type is integer
-  // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
-  MVT SimpleVT = LoadedVT.getSimpleVT();
-  // Read at least 8 bits (predicates are stored as 8-bit values)
-  unsigned FromTypeWidth = std::max(8U, (unsigned)SimpleVT.getSizeInBits());
+  const unsigned FromTypeWidth = LoadedVT.getSizeInBits();
 
   // Vector Setting
-  unsigned int FromType =
+  const unsigned FromType =
       (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
           ? NVPTX::PTXLdStInstCode::Signed
           : NVPTX::PTXLdStInstCode::Untyped;
@@ -1102,29 +1095,17 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   if (!Opcode)
     return false;
 
-  SDNode *NVPTXLD =
-      CurDAG->getMachineNode(*Opcode, DL, TargetVT, MVT::Other, Ops);
+  SDNode *NVPTXLD = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
   if (!NVPTXLD)
     return false;
 
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
+  MachineMemOperand *MemRef = LD->getMemOperand();
   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXLD), {MemRef});
 
-  ReplaceNode(N, NVPTXLD);
+  ReplaceNode(LD, NVPTXLD);
   return true;
 }
 
-static bool isSubVectorPackedInI32(EVT EltVT) {
-  // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
-  // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
-  // vectorized loads/stores with the actual element type for i8/i16 as that
-  // would require v8/v16 variants that do not exist.
-  // In order to load/store such vectors efficiently, in Type Legalization
-  // we split the vector into word-sized chunks (v2x16/v4i8). Now, we will
-  // lower to PTX as vectors of b32.
-  return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
-}
-
 static unsigned getLoadStoreVectorNumElts(SDNode *N) {
   switch (N->getOpcode()) {
   case NVPTXISD::LoadV2:
@@ -1142,21 +1123,21 @@ static unsigned getLoadStoreVectorNumElts(SDNode *N) {
 }
 
 bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
-  MemSDNode *MemSD = cast<MemSDNode>(N);
-  const EVT MemEVT = MemSD->getMemoryVT();
+  MemSDNode *LD = cast<MemSDNode>(N);
+  const EVT MemEVT = LD->getMemoryVT();
   if (!MemEVT.isSimple())
     return false;
   const MVT MemVT = MemEVT.getSimpleVT();
 
   // Address Space Setting
-  const unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
-  if (canLowerToLDG(*MemSD, *Subtarget, CodeAddrSpace))
-    return tryLDGLDU(N);
+  const unsigned CodeAddrSpace = getCodeAddrSpace(LD);
+  if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace))
+    return tryLDG(LD);
 
-  EVT EltVT = N->getValueType(0);
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
+  const MVT EltVT = LD->getSimpleValueType(0);
+  SDLoc DL(LD);
+  SDValue Chain = LD->getChain();
+  const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
 
   // Type Setting: fromType + fromTypeWidth
   //
@@ -1167,18 +1148,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   // Read at least 8 bits (predicates are stored as 8-bit values)
   // The last operand holds the original LoadSDNode::getExtensionType() value
   const unsigned TotalWidth = MemVT.getSizeInBits();
-  unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
-  unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
-                          ? NVPTX::PTXLdStInstCode::Signed
-                          : NVPTX::PTXLdStInstCode::Untyped;
+  const unsigned ExtensionType =
+      N->getConstantOperandVal(N->getNumOperands() - 1);
+  const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
+                                ? NVPTX::PTXLdStInstCode::Signed
+                                : NVPTX::PTXLdStInstCode::Untyped;
 
-  unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
-
-  if (isSubVectorPackedInI32(EltVT)) {
-    assert(ExtensionType == ISD::NON_EXTLOAD);
-    EltVT = MVT::i32;
-  }
+  const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
 
+  assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
   assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
          FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
@@ -1196,192 +1174,183 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   std::optional<unsigned> Opcode;
   switch (N->getOpcode()) {
   default:
-    return false;
+    llvm_unreachable("Unexpected opcode");
   case NVPTXISD::LoadV2:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2,
-                             NVPTX::LDV_i16_v2, NVPTX::LDV_i32_v2,
-                             NVPTX::LDV_i64_v2);
+    Opcode =
+        pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2,
+                        NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2);
     break;
   case NVPTXISD::LoadV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
-                             NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4,
-                             NVPTX::LDV_i64_v4);
+    Opcode =
+        pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4,
+                        NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4);
     break;
   case NVPTXISD::LoadV8:
-    Opcode =
-        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
-                        {/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */});
+    Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */},
+                             NVPTX::LDV_i32_v8, {/* no v8i64 */});
     break;
   }
   if (!Opcode)
     return false;
 
-  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
+  SDNode *NVPTXLD = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
 
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
-  CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});
+  MachineMemOperand *MemRef = LD->getMemOperand();
+  CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXLD), {MemRef});
 
-  ReplaceNode(N, LD);
+  ReplaceNode(LD, NVPTXLD);
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
-  auto *Mem = cast<MemSDNode>(N);
-
-  // If this is an LDG intrinsic, the address is the third operand. If its an
-  // LDG/LDU SD node (from custom vector handling), then its the second operand
-  SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
+bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
+  const EVT LoadedEVT = LD->getMemoryVT();
+  if (!LoadedEVT.isSimple())
+    return false;
+  const MVT LoadedVT = LoadedEVT.getSimpleVT();
 
-  const EVT OrigType = N->getValueType(0);
-  EVT EltVT = Mem->getMemoryVT();
-  unsigned NumElts = 1;
+  SDLoc DL(LD);
 
-  if (EltVT == MVT::i128 || EltVT == MVT::f128) {
-    EltVT = MVT::i64;
-    NumElts = 2;
-  }
-  if (EltVT.isVector()) {
-    NumElts = EltVT.getVectorNumElements();
-    EltVT = EltVT.getVectorElementType();
-    // vectors of 8/16bits type are loaded/stored as multiples of v4i8/v2x16
-    // elements.
-    if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
-        (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
-        (EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
-        (EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
-      assert(NumElts % OrigType.getVectorNumElements() == 0 &&
-             "NumElts must be divisible by the number of elts in subvectors");
-      EltVT = OrigType;
-      NumElts /= OrigType.getVectorNumElements();
-    }
+  const unsigned TotalWidth = LoadedVT.getSizeInBits();
+  unsigned ExtensionType;
+  unsigned NumElts;
+  if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
+    ExtensionType = Load->getExtensionType();
+    NumElts = 1;
+  } else {
+    ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+    NumElts = getLoadStoreVectorNumElts(LD);
   }
+  const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
+                                ? NVPTX::PTXLdStInstCode::Signed
+                                : NVPTX::PTXLdStInstCode::Untyped;
 
-  // Build the "promoted" result VTList for the load. If we are really loading
-  // i8s, then the return type will be promoted to i16 since we do not expose
-  // 8-bit registers in NVPTX.
-  const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
-  SmallVector<EVT, 5> InstVTs;
-  InstVTs.append(NumElts, NodeVT);
-  InstVTs.push_back(MVT::Other);
-  SDVTList InstVTList = CurDAG->getVTList(InstVTs);
-  SDValue Chain = N->getOperand(0);
+  const unsigned FromTypeWidth = TotalWidth / NumElts;
+
+  assert(!(LD->getSimpleValueType(0).isVector() &&
+           ExtensionType != ISD::NON_EXTLOAD));
+  assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
+         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   SDValue Base, Offset;
-  SelectADDR(Op1, Base, Offset);
-  SDValue Ops[] = {Base, Offset, Chain};
+  SelectADDR(LD->getOperand(1), Base, Offset);
+  SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
+                   Offset, LD->getChain()};
 
+  const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
   std::optional<unsigned> Opcode;
-  switch (N->getOpcode()) {
+  switch (LD->getOpcode()) {
   default:
-    return false;
+    llvm_unreachable("Unexpected opcode");
   case ISD::LOAD:
-    Opcode = pickOpcodeForVT(
-        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
-        NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
-        NVPTX::INT_PTX_LDG_GLOBAL_i64);
-    break;
-  case ISD::INTRINSIC_W_CHAIN:
-    Opcode = pickOpcodeForVT(
-        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
-        NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
-        NVPTX::INT_PTX_LDU_GLOBAL_i64);
+    Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8,
+                             NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32,
+                             NVPTX::LD_GLOBAL_NC_i64);
     break;
   case NVPTXISD::LoadV2:
     Opcode = pickOpcodeForVT(
-        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
-        NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
-        NVPTX::INT_PTX_LDG_G_v2i64_ELE);
-    break;
-  case NVPTXISD::LDUV2:
-    Opcode = pickOpcodeForVT(
-        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
-        NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
-        NVPTX::INT_PTX_LDU_G_v2i64_ELE);
+        TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16,
+        NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64);
     break;
   case NVPTXISD::LoadV4:
     Opcode = pickOpcodeForVT(
-        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
-        NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
-        NVPTX::INT_PTX_LDG_G_v4i64_ELE);
-    break;
-  case NVPTXISD::LDUV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                             NVPTX::INT_PTX_LDU_G_v4i8_ELE,
-                             NVPTX::INT_PTX_LDU_G_v4i16_ELE,
-                             NVPTX::INT_PTX_LDU_G_v4i32_ELE, {/* no v4i64 */});
+        TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16,
+        NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64);
     break;
   case NVPTXISD::LoadV8:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
-                             {/* no v8i16 */}, NVPTX::INT_PTX_LDG_G_v8i32_ELE,
-                             {/* no v8i64 */});
+    Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */},
+                             NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */});
     break;
   }
   if (!Opcode)
     return false;
 
-  SDLoc DL(N);
-  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
+  SDNode *NVPTXLDG = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
 
-  // For automatic generation of LDG (through SelectLoad[Vector], not the
-  // intrinsics), we may have an extending load like:
-  //
-  //   i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
-  //
-  // In this case, the matching logic above will select a load for the original
-  // memory type (in this case, i8) and our types will not match (the node needs
-  // to return an i32 in this case). Our LDG/LDU nodes do not support the
-  // concept of sign-/zero-extension, so emulate it here by adding an explicit
-  // CVT instruction. Ptxas should clean up any redundancies here.
-
-  LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
-
-  if (OrigType != EltVT &&
-      (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
-    // We have an extending-load. The instruction we selected operates on the
-    // smaller type, but the SDNode we are replacing has the larger type. We
-    // need to emit a CVT to make the types match.
-    unsigned CvtOpc =
-        GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);
-
-    // For each output value, apply the manual sign/zero-extension and make sure
-    // all users of the load go through that CVT.
-    for (unsigned i = 0; i != NumElts; ++i) {
-      SDValue Res(LD, i);
-      SDValue OrigVal(N, i);
-
-      SDNode *CvtNode =
-        CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
-                               CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE,
-                                                         DL, MVT::i32));
-      ReplaceUses(OrigVal, SDValue(CvtNode, 0));
-    }
+  ReplaceNode(LD, NVPTXLDG);
+  return true;
+}
+
+bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
+  auto *LD = cast<MemSDNode>(N);
+
+  unsigned NumElts;
+  switch (N->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case ISD::INTRINSIC_W_CHAIN:
+    NumElts = 1;
+    break;
+  case NVPTXISD::LDUV2:
+    NumElts = 2;
+    break;
+  case NVPTXISD::LDUV4:
+    NumElts = 4;
+    break;
   }
 
-  ReplaceNode(N, LD);
+  const MVT::SimpleValueType SelectVT =
+      MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy;
+
+  // If this is an LDU intrinsic, the address is the third operand. If its an
+  // LDU SD node (from custom vector handling), then its the second operand
+  SDValue Addr =
+      LD->getOperand(LD->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
+
+  SDValue Base, Offset;
+  SelectADDR(Addr, Base, Offset);
+  SDValue Ops[] = {Base, Offset, LD->getChain()};
+
+  std::optional<unsigned> Opcode;
+  switch (N->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case ISD::INTRINSIC_W_CHAIN:
+    Opcode =
+        pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16,
+                        NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64);
+    break;
+  case NVPTXISD::LDUV2:
+    Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8,
+                             NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32,
+                             NVPTX::LDU_GLOBAL_v2i64);
+    break;
+  case NVPTXISD::LDUV4:
+    Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8,
+                             NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32,
+                             {/* no v4i64 */});
+    break;
+  }
+  if (!Opcode)
+    return false;
+
+  SDLoc DL(N);
+  SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
+
+  ReplaceNode(LD, NVPTXLDU);
   return true;
 }
 
 bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   MemSDNode *ST = cast<MemSDNode>(N);
   assert(ST->writeMem() && "Expected store");
-  StoreSDNode *PlainStore = dyn_cast<StoreSDNode>(N);
-  AtomicSDNode *AtomicStore = dyn_cast<AtomicSDNode>(N);
+  StoreSDNode *PlainStore = dyn_cast<StoreSDNode>(ST);
+  AtomicSDNode *AtomicStore = dyn_cast<AtomicSDNode>(ST);
   assert((PlainStore || AtomicStore) && "Expected store");
 
   // do not support pre/post inc/dec
   if (PlainStore && PlainStore->isIndexed())
     return false;
 
-  EVT StoreVT = ST->getMemoryVT();
+  const EVT StoreVT = ST->getMemoryVT();
   if (!StoreVT.isSimple())
     return false;
 
   // Address Space Setting
-  unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
+  const unsigned CodeAddrSpace = getCodeAddrSpace(ST);
 
-  SDLoc DL(N);
+  SDLoc DL(ST);
   SDValue Chain = ST->getChain();
-  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
+  const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
 
   // Vector Setting
   const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
@@ -1417,85 +1386,78 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   if (!NVPTXST)
     return false;
 
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
+  MachineMemOperand *MemRef = ST->getMemOperand();
   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXST), {MemRef});
-  ReplaceNode(N, NVPTXST);
+  ReplaceNode(ST, NVPTXST);
   return true;
 }
 
 bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
-  SDValue Op1 = N->getOperand(1);
-  EVT EltVT = Op1.getValueType();
-  MemSDNode *MemSD = cast<MemSDNode>(N);
-  EVT StoreVT = MemSD->getMemoryVT();
+  MemSDNode *ST = cast<MemSDNode>(N);
+  const EVT StoreVT = ST->getMemoryVT();
   assert(StoreVT.isSimple() && "Store value is not simple");
 
   // Address Space Setting
-  unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
+  const unsigned CodeAddrSpace = getCodeAddrSpace(ST);
   if (CodeAddrSpace == NVPTX::AddressSpace::Const) {
     report_fatal_error("Cannot store to pointer that points to constant "
                        "memory space");
   }
 
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
+  SDLoc DL(ST);
+  SDValue Chain = ST->getChain();
+  const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
 
   // Type Setting: toType + toTypeWidth
   // - for integer type, always use 'u'
   const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
 
-  unsigned NumElts = getLoadStoreVectorNumElts(N);
-
-  SmallVector<SDValue, 16> Ops(N->ops().slice(1, NumElts));
-  SDValue N2 = N->getOperand(NumElts + 1);
-  unsigned ToTypeWidth = TotalWidth / NumElts;
+  const unsigned NumElts = getLoadStoreVectorNumElts(ST);
 
-  if (isSubVectorPackedInI32(EltVT)) {
-    EltVT = MVT::i32;
-  }
+  SmallVector<SDValue, 16> Ops(ST->ops().slice(1, NumElts));
+  SDValue Addr = N->getOperand(NumElts + 1);
+  const unsign...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/143936


More information about the llvm-commits mailing list