[llvm] [NVPTX] Combine addressing-mode variants of ld, st, wmma (PR #129102)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 27 11:00:16 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>

This change folds together the _ari, _ari64, and _asi variants of these instructions into a single instruction capable of holding any address. This allows for the removal of a lot of unnecessary code and moves us towards a standard way of representing an address in NVPTX.

---

Patch is 58.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129102.diff


5 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+157-410) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1-11) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+51-139) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+78-128) 
- (modified) llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 971a128aadfdb..08022104bfedf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -930,8 +930,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace, MF)) {
     return tryLDGLDU(N);
   }
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(LD->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = N->getOperand(0);
@@ -964,37 +962,24 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
     FromType = getLdStRegType(ScalarVT);
 
   // Create the machine instruction DAG
-  SDValue N1 = N->getOperand(1);
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
-
-  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-                                getI32Imm(CodeAddrSpace, DL),
-                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                                getI32Imm(FromTypeWidth, DL)});
-
-  if (SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
-    Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_asi, NVPTX::LD_i16_asi,
-                             NVPTX::LD_i32_asi, NVPTX::LD_i64_asi,
-                             NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(N1.getNode(), N1, Base, Offset);
-      Opcode =
-          pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari_64, NVPTX::LD_i16_ari_64,
-                          NVPTX::LD_i32_ari_64, NVPTX::LD_i64_ari_64,
-                          NVPTX::LD_f32_ari_64, NVPTX::LD_f64_ari_64);
-    } else {
-      SelectADDRri(N1.getNode(), N1, Base, Offset);
-      Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari, NVPTX::LD_i16_ari,
-                               NVPTX::LD_i32_ari, NVPTX::LD_i64_ari,
-                               NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
-    }
-  }
+  SelectADDR(N->getOperand(1), Base, Offset);
+  SDValue Ops[] = {getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
+
+  const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
+  const std::optional<unsigned> Opcode =
+      pickOpcodeForVT(TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32,
+                      NVPTX::LD_i64, NVPTX::LD_f32, NVPTX::LD_f64);
   if (!Opcode)
     return false;
-  Ops.append({Base, Offset, Chain});
 
   SDNode *NVPTXLD =
       CurDAG->getMachineNode(*Opcode, DL, TargetVT, MVT::Other, Ops);
@@ -1030,8 +1015,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
     return tryLDGLDU(N);
   }
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = N->getOperand(0);
@@ -1079,77 +1062,38 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     FromTypeWidth = 32;
   }
 
-  SDValue Op1 = N->getOperand(1);
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  SDNode *LD;
+  SelectADDR(N->getOperand(1), Base, Offset);
+  SDValue Ops[] = {getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
 
-  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-                                getI32Imm(CodeAddrSpace, DL),
-                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                                getI32Imm(FromTypeWidth, DL)});
-
-  if (SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
-    switch (N->getOpcode()) {
-    default:
-      return false;
-    case NVPTXISD::LoadV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::LDV_i8_v2_asi, NVPTX::LDV_i16_v2_asi,
-                               NVPTX::LDV_i32_v2_asi, NVPTX::LDV_i64_v2_asi,
-                               NVPTX::LDV_f32_v2_asi, NVPTX::LDV_f64_v2_asi);
-      break;
-    case NVPTXISD::LoadV4:
-      Opcode =
-          pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_asi,
-                          NVPTX::LDV_i16_v4_asi, NVPTX::LDV_i32_v4_asi,
-                          std::nullopt, NVPTX::LDV_f32_v4_asi, std::nullopt);
-      break;
-    }
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case NVPTXISD::LoadV2:
-        Opcode =
-            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                            NVPTX::LDV_i8_v2_ari_64, NVPTX::LDV_i16_v2_ari_64,
-                            NVPTX::LDV_i32_v2_ari_64, NVPTX::LDV_i64_v2_ari_64,
-                            NVPTX::LDV_f32_v2_ari_64, NVPTX::LDV_f64_v2_ari_64);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari_64,
-            NVPTX::LDV_i16_v4_ari_64, NVPTX::LDV_i32_v4_ari_64, std::nullopt,
-            NVPTX::LDV_f32_v4_ari_64, std::nullopt);
-        break;
-      }
-    } else {
-      SelectADDRri(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::LDV_i8_v2_ari, NVPTX::LDV_i16_v2_ari,
-                                 NVPTX::LDV_i32_v2_ari, NVPTX::LDV_i64_v2_ari,
-                                 NVPTX::LDV_f32_v2_ari, NVPTX::LDV_f64_v2_ari);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode =
-            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari,
-                            NVPTX::LDV_i16_v4_ari, NVPTX::LDV_i32_v4_ari,
-                            std::nullopt, NVPTX::LDV_f32_v4_ari, std::nullopt);
-        break;
-      }
-    }
+  std::optional<unsigned> Opcode;
+  switch (N->getOpcode()) {
+  default:
+    return false;
+  case NVPTXISD::LoadV2:
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2,
+                        NVPTX::LDV_i16_v2, NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2,
+                        NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
+    break;
+  case NVPTXISD::LoadV4:
+    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+                             NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
+                             NVPTX::LDV_f32_v4, std::nullopt);
+    break;
   }
   if (!Opcode)
     return false;
-  Ops.append({Base, Offset, Chain});
-  LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
+
+  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
 
   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
   CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});
@@ -1197,176 +1141,58 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   SDValue Chain = N->getOperand(0);
 
   std::optional<unsigned> Opcode;
-  SDLoc DL(N);
-  SDNode *LD;
-  SDValue Base, Offset;
-
-  if (SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
-    switch (N->getOpcode()) {
-    default:
-      return false;
-    case ISD::LOAD:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i16asi, NVPTX::INT_PTX_LDG_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i64asi, NVPTX::INT_PTX_LDG_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_f64asi);
-      break;
-    case ISD::INTRINSIC_W_CHAIN:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i16asi, NVPTX::INT_PTX_LDU_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i64asi, NVPTX::INT_PTX_LDU_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_f64asi);
-      break;
-    case NVPTXISD::LoadV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDG_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LDUV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDU_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LoadV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDG_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    case NVPTXISD::LDUV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDU_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    }
-  } else {
-    if (TM.is64Bit()) {
-      SelectADDRri64(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      }
-    } else {
-      SelectADDRri(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i16ari, NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i64ari, NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i16ari, NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i64ari, NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      }
-    }
+  switch (N->getOpcode()) {
+  default:
+    return false;
+  case ISD::LOAD:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
+        NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
+        NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
+        NVPTX::INT_PTX_LDG_GLOBAL_f64);
+    break;
+  case ISD::INTRINSIC_W_CHAIN:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
+        NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
+        NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
+        NVPTX::INT_PTX_LDU_GLOBAL_f64);
+    break;
+  case NVPTXISD::LoadV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LDUV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LoadV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+    break;
+  case NVPTXISD::LDUV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
+    break;
   }
   if (!Opcode)
     return false;
+
+  SDLoc DL(N);
+  SDValue Base, Offset;
+  SelectADDR(Op1, Base, Offset);
   SDValue Ops[] = {Base, Offset, Chain};
-  LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
+  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
 
   // For automatic generation of LDG (through SelectLoad[Vector], not the
   // intrinsics), we may have an extending load like:
@@ -1424,8 +1250,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
 
   // Address Space Setting
   unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(ST->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = ST->getChain();
@@ -1450,38 +1274,28 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
 
   // Create the machine instruction DAG
   SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
-  SDValue BasePtr = ST->getBasePtr();
+
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  MVT::SimpleValueType SourceVT =
+  SelectADDR(ST->getBasePtr(), Base, Offset);
+
+  SDValue Ops[] = {Value,
+                   getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(ToType, DL),
+                   getI32Imm(ToTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
+
+  const MVT::SimpleValueType SourceVT =
       Value.getNode()->getSimpleValueType(0).SimpleTy;
-
-  SmallVector<SDValue, 12> Ops(
-      {Value, getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-       getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
-       getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)});
-
-  if (SelectADDRsi(BasePtr.getNode(), BasePtr, Base, Offset)) {
-    Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_asi, NVPTX::ST_i16_asi,
-                             NVPTX::ST_i32_asi, NVPTX::ST_i64_asi,
-                             NVPTX::ST_f32_asi, NVPTX::ST_f64_asi);
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(BasePtr.getNode(), BasePtr, Base, Offset);
-      Opcode =
-          pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari_64, NVPTX::ST_i16_ari_64,
-                          NVPTX::ST_i32_ari_64, NVPTX::ST_i64_ari_64,
-                          NVPTX::ST_f32_ari_64, NVPTX::ST_f64_ari_64);
-    } else {
-      SelectADDRri(BasePtr.getNode(), BasePtr, Base, Offset);
-      Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari, NVPTX::ST_i16_ari,
-                      ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129102


More information about the llvm-commits mailing list