[llvm] [NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+/ptx88+ (PR #139292)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 9 10:02:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Drew Kersnar (dakersnar)

<details>
<summary>Changes</summary>

PTX ISA 8.8 introduces 256-bit-wide vector loads and stores, which are available only for the global state space on sm_100 and newer targets. This change extends the NVPTX backend to lower these loads/stores: `.v8` with 32-bit elements and `.v4` with 64-bit elements. It also overrides `getLoadStoreVecRegBitWidth` for NVPTX so that the LoadStoreVectorizer can form these wider vector operations.

See the PTX ISA specification for the three relevant instructions:
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld-global-nc
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st

---

Patch is 183.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139292.diff


15 Files Affected:

- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3) 
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+2-1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+72-9) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+54-5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+27-6) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+10) 
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3) 
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp (+7) 
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2) 
- (added) llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll (+520) 
- (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing-invariant.ll (+549) 
- (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing.ll (+543) 
- (added) llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll (+1442) 
- (added) llvm/test/Transforms/LoadStoreVectorizer/NVPTX/256-bit.ll (+728) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 0b137250e4e59..ab1c3c19168af 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -319,6 +319,9 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
     case NVPTX::PTXLdStInstCode::V4:
       O << ".v4";
       return;
+    case NVPTX::PTXLdStInstCode::V8:
+      O << ".v8";
+      return;
     }
     // TODO: evaluate whether cases not covered by this switch are bugs
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 83090ab720c73..2468b8f43ae94 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -199,7 +199,8 @@ enum FromType {
 enum VecType {
   Scalar = 1,
   V2 = 2,
-  V4 = 4
+  V4 = 4,
+  V8 = 8
 };
 } // namespace PTXLdStInstCode
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 6f6084b99dda2..74594837d92cc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     return;
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
+  case NVPTXISD::LoadV8:
     if (tryLoadVector(N))
       return;
     break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     break;
   case NVPTXISD::StoreV2:
   case NVPTXISD::StoreV4:
+  case NVPTXISD::StoreV8:
     if (tryStoreVector(N))
       return;
     break;
@@ -1195,6 +1197,12 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     FromTypeWidth = TotalWidth / 4;
     VecType = NVPTX::PTXLdStInstCode::V4;
     break;
+  case NVPTXISD::LoadV8:
+    if (!Subtarget->has256BitMaskedLoadStore())
+      return false;
+    FromTypeWidth = TotalWidth / 8;
+    VecType = NVPTX::PTXLdStInstCode::V8;
+    break;
   default:
     return false;
   }
@@ -1205,7 +1213,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   }
 
   assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
+         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   SDValue Offset, Base;
   SelectADDR(N->getOperand(1), Base, Offset);
@@ -1230,9 +1238,22 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
                         NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
     break;
   case NVPTXISD::LoadV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
-                             NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
-                             NVPTX::LDV_f32_v4, std::nullopt);
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+                        NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
+                        NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
+    break;
+  case NVPTXISD::LoadV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::LDV_i32_v8;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::LDV_f32_v8;
+      break;
+    default:
+      return false;
+    }
     break;
   }
   if (!Opcode)
@@ -1328,7 +1349,8 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     Opcode = pickOpcodeForVT(
         EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
         NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
-        std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+        NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
+        NVPTX::INT_PTX_LDG_G_v4f64_ELE);
     break;
   case NVPTXISD::LDUV4:
     Opcode = pickOpcodeForVT(
@@ -1336,6 +1358,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
         NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
         std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
     break;
+  case NVPTXISD::LoadV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8f32_ELE;
+      break;
+    case MVT::v2i16:
+    case MVT::v2f16:
+    case MVT::v2bf16:
+    case MVT::v4i8:
+      Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+      break;
+    default:
+      return false;
+    }
+    break;
   }
   if (!Opcode)
     return false;
@@ -1502,6 +1542,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
     N2 = N->getOperand(5);
     ToTypeWidth = TotalWidth / 4;
     break;
+  case NVPTXISD::StoreV8:
+    if (!Subtarget->has256BitMaskedLoadStore())
+      return false;
+    VecType = NVPTX::PTXLdStInstCode::V8;
+    Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
+                N->getOperand(4), N->getOperand(5), N->getOperand(6),
+                N->getOperand(7), N->getOperand(8)});
+    N2 = N->getOperand(9);
+    ToTypeWidth = TotalWidth / 8;
+    break;
   default:
     return false;
   }
@@ -1512,7 +1562,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   }
 
   assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
-         TotalWidth <= 128 && "Invalid width for store");
+         TotalWidth <= 256 && "Invalid width for store");
 
   SDValue Offset, Base;
   SelectADDR(N2, Base, Offset);
@@ -1533,9 +1583,22 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
                         NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
     break;
   case NVPTXISD::StoreV4:
-    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
-                             NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
-                             NVPTX::STV_f32_v4, std::nullopt);
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
+                        NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
+                        NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
+    break;
+  case NVPTXISD::StoreV8:
+    switch (EltVT.getSimpleVT().SimpleTy) {
+    case MVT::i32:
+      Opcode = NVPTX::STV_i32_v8;
+      break;
+    case MVT::f32:
+      Opcode = NVPTX::STV_f32_v8;
+      break;
+    default:
+      return false;
+    }
     break;
   }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3769aae7b620f..d7883b5d526aa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
   case MVT::v2f32:
   case MVT::v4f32:
   case MVT::v2f64:
+  case MVT::v4i64:
+  case MVT::v4f64:
+  case MVT::v8i32:
+  case MVT::v8f32:
+  case MVT::v16f16:  // <8 x f16x2>
+  case MVT::v16bf16: // <8 x bf16x2>
+  case MVT::v16i16:  // <8 x i16x2>
+  case MVT::v32i8:   // <8 x i8x4>
     return true;
   }
 }
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
 //    - unsigned int NumElts - The number of elements in the final vector
 //    - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT) {
+getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
   switch (VectorVT.SimpleTy) {
   default:
     return std::nullopt;
+  case MVT::v4i64:
+  case MVT::v4f64:
+  case MVT::v8i32:
+  case MVT::v8f32:
+    // This is a "native" vector type iff the address space is global
+    // and the target supports 256-bit loads/stores
+    if (!CanLowerTo256Bit)
+      return std::nullopt;
+    LLVM_FALLTHROUGH;
   case MVT::v2i8:
   case MVT::v2i16:
   case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
+  case MVT::v16f16:  // <8 x f16x2>
+  case MVT::v16bf16: // <8 x bf16x2>
+  case MVT::v16i16:  // <8 x i16x2>
+  case MVT::v32i8:   // <8 x i8x4>
+    // This can be upsized into a "native" vector type iff the address space is
+    // global and the target supports 256-bit loads/stores.
+    if (!CanLowerTo256Bit)
+      return std::nullopt;
+    LLVM_FALLTHROUGH;
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::ProxyReg)
     MAKE_CASE(NVPTXISD::LoadV2)
     MAKE_CASE(NVPTXISD::LoadV4)
+    MAKE_CASE(NVPTXISD::LoadV8)
     MAKE_CASE(NVPTXISD::LDUV2)
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)
     MAKE_CASE(NVPTXISD::StoreV4)
+    MAKE_CASE(NVPTXISD::StoreV8)
     MAKE_CASE(NVPTXISD::FSHL_CLAMP)
     MAKE_CASE(NVPTXISD::FSHR_CLAMP)
     MAKE_CASE(NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
     return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
+  // 256-bit vectors are only allowed iff the address is global
+  // and the target supports 256-bit loads/stores
+  unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+  bool CanLowerTo256Bit =
+      AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore();
+  const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT, CanLowerTo256Bit);
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   case 4:
     Opcode = NVPTXISD::StoreV4;
     break;
+  case 8:
+    Opcode = NVPTXISD::StoreV8;
+    break;
   }
 
   SmallVector<SDValue, 8> Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
 
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
-                              SmallVectorImpl<SDValue> &Results) {
+                              SmallVectorImpl<SDValue> &Results,
+                              bool TargetHas256BitVectorLoadStore) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+  // 256-bit vectors are only allowed iff the address is global
+  // and the target supports 256-bit loads/stores
+  unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+  bool CanLowerTo256Bit =
+      AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
+  const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT, CanLowerTo256Bit);
   if (!NumEltsAndEltVT)
     return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
         DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
     break;
   }
+  case 8: {
+    Opcode = NVPTXISD::LoadV8;
+    EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
+                     LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
+    LdResVTs = DAG.getVTList(ListVTs);
+    break;
+  }
   }
   SDLoc DL(LD);
 
@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
-    ReplaceLoadVector(N, DAG, Results);
+    ReplaceLoadVector(N, DAG, Results, STI.has256BitMaskedLoadStore());
     return;
   case ISD::INTRINSIC_W_CHAIN:
     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3dff83d74538b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -84,10 +84,12 @@ enum NodeType : unsigned {
   FIRST_MEMORY_OPCODE,
   LoadV2 = FIRST_MEMORY_OPCODE,
   LoadV4,
+  LoadV8,
   LDUV2, // LDU.v2
   LDUV4, // LDU.v4
   StoreV2,
   StoreV4,
+  StoreV8,
   LoadParam,
   LoadParamV2,
   LoadParamV4,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a384cb79d645a..d0f3fb4ec1c1d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2425,7 +2425,7 @@ let mayStore=1, hasSideEffects=0 in {
 // The following is used only in and after vector elementizations.  Vector
 // elementization happens at the machine instruction level, so the following
 // instructions never appear in the DAG.
-multiclass LD_VEC<NVPTXRegClass regclass> {
+multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2),
     (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
@@ -2438,17 +2438,27 @@ multiclass LD_VEC<NVPTXRegClass regclass> {
          LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
     "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
+  if support_v8 then {
+    def _v8 : NVPTXInst<
+      (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+            regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+      (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+           i32imm:$fromWidth, ADDR:$addr),
+      "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+      "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
+      "[$addr];", []>;
+  }
 }
 let mayLoad=1, hasSideEffects=0 in {
   defm LDV_i8  : LD_VEC<Int16Regs>;
   defm LDV_i16 : LD_VEC<Int16Regs>;
-  defm LDV_i32 : LD_VEC<Int32Regs>;
+  defm LDV_i32 : LD_VEC<Int32Regs, true>;
   defm LDV_i64 : LD_VEC<Int64Regs>;
-  defm LDV_f32 : LD_VEC<Float32Regs>;
+  defm LDV_f32 : LD_VEC<Float32Regs, true>;
   defm LDV_f64 : LD_VEC<Float64Regs>;
 }
 
-multiclass ST_VEC<NVPTXRegClass regclass> {
+multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs),
     (ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
@@ -2463,14 +2473,25 @@ multiclass ST_VEC<NVPTXRegClass regclass> {
          LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
     "\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
+  if support_v8 then {
+    def _v8 : NVPTXInst<
+      (outs),
+      (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
+           regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
+           LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+           i32imm:$fromWidth, ADDR:$addr),
+      "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+      "\t[$addr], "
+      "{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
+  }
 }
 
 let mayStore=1, hasSideEffects=0 in {
   defm STV_i8  : ST_VEC<Int16Regs>;
   defm STV_i16 : ST_VEC<Int16Regs>;
-  defm STV_i32 : ST_VEC<Int32Regs>;
+  defm STV_i32 : ST_VEC<Int32Regs, true>;
   defm STV_i64 : ST_VEC<Int64Regs>;
-  defm STV_f32 : ST_VEC<Float32Regs>;
+  defm STV_f32 : ST_VEC<Float32Regs, true>;
   defm STV_f64 : ST_VEC<Float64Regs>;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 7b139d7b79e7d..cdbf29c140429 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2388,6 +2388,12 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
             (ins ADDR:$src),
             "ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
 
+class VLDG_G_ELE_V8<string TyStr, NVPTXRegClass regclass> :
+  NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+                  regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+             (ins ADDR:$src),
+             "ld.global.nc.v8." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
+
 // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
 def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
 def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
@@ -2401,6 +2407,10 @@ def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
 def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
 def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
 
+def INT_PTX_LDG_G_v4i64_ELE : VLDG_G_ELE_V4<"u64", Int64Regs>;
+def INT_PTX_LDG_G_v4f64_ELE : VLDG_G_ELE_V4<"f64", Float64Regs>;
+def INT_PTX_LDG_G_v8i32_ELE : VLDG_G_ELE_V8<"u32", Int32Regs>;
+def INT_PTX_LDG_G_v8f32_ELE : VLDG_G_ELE_V8<"f32", Float32Regs>;
 
 multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
   if Supports32 then
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 0a4fc8d1435be..5552bba728160 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -72,6 +72,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
 
   const SelectionDAGTargetInfo *getSelectionDAGInfo() const override;
 
+  bool has256BitMaskedLoadStore() const {
+    return SmVersion >= 100 && PTXVersion >= 88;
+  }
   bool hasAtomAddF64() const { return SmVersion >= 60; }
   bool hasAtomScope() const { return SmVersion >= 60; }
   bool hasAtomBitwise64() const { return SmVersion >= 32; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 66c5139f8c2cc..1d8525fd4656f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -591,6 +591,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
+  // 256 bit loads/stores are currently only supported for global address space
+  if (AddrSpace == ADDRESS_SPACE_GLOBAL && ST->has256BitMaskedLoadStore())
+    return 256;
+  return 128;
+}
+
 unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
   if (isa<AllocaInst>(V))
     return ADDRESS_SPACE_LOCAL;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index a9bd5a0d01043..98aea4e535f0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -173,6 +173,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
+
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                           Value *NewV) const override;
   unsigned getAssumedAddrSpace(const Value *V) const override;
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
new file mode 100644
index 0000000000000..f4abcb37aa894
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -0,0 +1,520 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 -verify-machineinstrs | FileCheck %s -check-prefixes=SM90
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 -verify-machineinstrs | FileCheck %s -check-prefixes=SM100
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; For 256-bit vectors, check that invariant loads from the
+; global addrspace are lowered to ld.globa...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/139292


More information about the llvm-commits mailing list