[llvm] [NVPTX] Combine addressing-mode variants of ld, st, wmma (PR #129102)

Kevin McAfee via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 27 15:00:52 PST 2025


================
@@ -1197,176 +1141,58 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   SDValue Chain = N->getOperand(0);
 
   std::optional<unsigned> Opcode;
-  SDLoc DL(N);
-  SDNode *LD;
-  SDValue Base, Offset;
-
-  if (SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
-    switch (N->getOpcode()) {
-    default:
-      return false;
-    case ISD::LOAD:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i16asi, NVPTX::INT_PTX_LDG_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i64asi, NVPTX::INT_PTX_LDG_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_f64asi);
-      break;
-    case ISD::INTRINSIC_W_CHAIN:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i16asi, NVPTX::INT_PTX_LDU_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i64asi, NVPTX::INT_PTX_LDU_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_f64asi);
-      break;
-    case NVPTXISD::LoadV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDG_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LDUV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDU_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LoadV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDG_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    case NVPTXISD::LDUV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDU_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    }
-  } else {
-    if (TM.is64Bit()) {
-      SelectADDRri64(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      }
-    } else {
-      SelectADDRri(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i16ari, NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i64ari, NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i16ari, NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i64ari, NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      }
-    }
+  switch (N->getOpcode()) {
+  default:
+    return false;
+  case ISD::LOAD:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
+        NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
+        NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
+        NVPTX::INT_PTX_LDG_GLOBAL_f64);
+    break;
+  case ISD::INTRINSIC_W_CHAIN:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
+        NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
+        NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
+        NVPTX::INT_PTX_LDU_GLOBAL_f64);
+    break;
+  case NVPTXISD::LoadV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LDUV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LoadV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+    break;
+  case NVPTXISD::LDUV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
+    break;
   }
   if (!Opcode)
     return false;
+
+  SDLoc DL(N);
+  SDValue Base, Offset;
+  SelectADDR(Op1, Base, Offset);
----------------
kalxr wrote:

Tiny nit: in all the other cases we call `SelectADDR` before the opcode switch. I would prefer consistency here unless there's a reason for the different ordering.

https://github.com/llvm/llvm-project/pull/129102


More information about the llvm-commits mailing list