[llvm] [NVPTX] Vectorize and lower 256-bit global loads/stores for sm_100+/ptx88+ (PR #139292)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 9 10:02:25 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-transforms
Author: Drew Kersnar (dakersnar)
PTX 8.8 introduces 256-bit-wide vector loads/stores for the global address space on sm_100 and newer targets. This change extends the backend to lower these loads/stores, and overrides getLoadStoreVecRegBitWidth for NVPTX so the LoadStoreVectorizer can form the wider vector operations.
See the PTX spec for the three relevant instructions:
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld-global-nc
- https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
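As a rough illustration (a hand-written sketch, not one of the patch's own tests; the function name and CHECK lines are illustrative), a 256-bit load/store pair in the global address space can now be lowered to single v8 instructions instead of being split into two 128-bit operations:

```llvm
; llc -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88
define void @copy_v8i32(ptr addrspace(1) %in, ptr addrspace(1) %out) {
; CHECK: ld.global.v8
; CHECK: st.global.v8
  %v = load <8 x i32>, ptr addrspace(1) %in, align 32
  store <8 x i32> %v, ptr addrspace(1) %out, align 32
  ret void
}
```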
---
Patch is 183.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139292.diff
15 Files Affected:
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3)
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+2-1)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+72-9)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+54-5)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+27-6)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+10)
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3)
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp (+7)
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2)
- (added) llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll (+520)
- (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing-invariant.ll (+549)
- (added) llvm/test/CodeGen/NVPTX/load-store-256-addressing.ll (+543)
- (added) llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll (+1442)
- (added) llvm/test/Transforms/LoadStoreVectorizer/NVPTX/256-bit.ll (+728)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 0b137250e4e59..ab1c3c19168af 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -319,6 +319,9 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
case NVPTX::PTXLdStInstCode::V4:
O << ".v4";
return;
+ case NVPTX::PTXLdStInstCode::V8:
+ O << ".v8";
+ return;
}
// TODO: evaluate whether cases not covered by this switch are bugs
return;
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 83090ab720c73..2468b8f43ae94 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -199,7 +199,8 @@ enum FromType {
enum VecType {
Scalar = 1,
V2 = 2,
- V4 = 4
+ V4 = 4,
+ V8 = 8
};
} // namespace PTXLdStInstCode
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 6f6084b99dda2..74594837d92cc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
return;
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
+ case NVPTXISD::LoadV8:
if (tryLoadVector(N))
return;
break;
@@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
break;
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
+ case NVPTXISD::StoreV8:
if (tryStoreVector(N))
return;
break;
@@ -1195,6 +1197,12 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
+ case NVPTXISD::LoadV8:
+ if (!Subtarget->has256BitMaskedLoadStore())
+ return false;
+ FromTypeWidth = TotalWidth / 8;
+ VecType = NVPTX::PTXLdStInstCode::V8;
+ break;
default:
return false;
}
@@ -1205,7 +1213,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
}
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
- FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
+ FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
SDValue Offset, Base;
SelectADDR(N->getOperand(1), Base, Offset);
@@ -1230,9 +1238,22 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
break;
case NVPTXISD::LoadV4:
- Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
- NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
- NVPTX::LDV_f32_v4, std::nullopt);
+ Opcode =
+ pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+ NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
+ NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
+ break;
+ case NVPTXISD::LoadV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::LDV_i32_v8;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::LDV_f32_v8;
+ break;
+ default:
+ return false;
+ }
break;
}
if (!Opcode)
@@ -1328,7 +1349,8 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
Opcode = pickOpcodeForVT(
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
- std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+ NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
+ NVPTX::INT_PTX_LDG_G_v4f64_ELE);
break;
case NVPTXISD::LDUV4:
Opcode = pickOpcodeForVT(
@@ -1336,6 +1358,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
break;
+ case NVPTXISD::LoadV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8f32_ELE;
+ break;
+ case MVT::v2i16:
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ case MVT::v4i8:
+ Opcode = NVPTX::INT_PTX_LDG_G_v8i32_ELE;
+ break;
+ default:
+ return false;
+ }
+ break;
}
if (!Opcode)
return false;
@@ -1502,6 +1542,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
N2 = N->getOperand(5);
ToTypeWidth = TotalWidth / 4;
break;
+ case NVPTXISD::StoreV8:
+ if (!Subtarget->has256BitMaskedLoadStore())
+ return false;
+ VecType = NVPTX::PTXLdStInstCode::V8;
+ Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
+ N->getOperand(4), N->getOperand(5), N->getOperand(6),
+ N->getOperand(7), N->getOperand(8)});
+ N2 = N->getOperand(9);
+ ToTypeWidth = TotalWidth / 8;
+ break;
default:
return false;
}
@@ -1512,7 +1562,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
}
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
- TotalWidth <= 128 && "Invalid width for store");
+ TotalWidth <= 256 && "Invalid width for store");
SDValue Offset, Base;
SelectADDR(N2, Base, Offset);
@@ -1533,9 +1583,22 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
break;
case NVPTXISD::StoreV4:
- Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
- NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
- NVPTX::STV_f32_v4, std::nullopt);
+ Opcode =
+ pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
+ NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
+ NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
+ break;
+ case NVPTXISD::StoreV8:
+ switch (EltVT.getSimpleVT().SimpleTy) {
+ case MVT::i32:
+ Opcode = NVPTX::STV_i32_v8;
+ break;
+ case MVT::f32:
+ Opcode = NVPTX::STV_f32_v8;
+ break;
+ default:
+ return false;
+ }
break;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3769aae7b620f..d7883b5d526aa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -162,6 +162,14 @@ static bool IsPTXVectorType(MVT VT) {
case MVT::v2f32:
case MVT::v4f32:
case MVT::v2f64:
+ case MVT::v4i64:
+ case MVT::v4f64:
+ case MVT::v8i32:
+ case MVT::v8f32:
+ case MVT::v16f16: // <8 x f16x2>
+ case MVT::v16bf16: // <8 x bf16x2>
+ case MVT::v16i16: // <8 x i16x2>
+ case MVT::v32i8: // <8 x i8x4>
return true;
}
}
@@ -179,7 +187,7 @@ static bool Is16bitsType(MVT VT) {
// - unsigned int NumElts - The number of elements in the final vector
// - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT) {
+getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
if (!VectorEVT.isSimple())
return std::nullopt;
const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -199,6 +207,15 @@ getVectorLoweringShape(EVT VectorEVT) {
switch (VectorVT.SimpleTy) {
default:
return std::nullopt;
+ case MVT::v4i64:
+ case MVT::v4f64:
+ case MVT::v8i32:
+ case MVT::v8f32:
+ // This is a "native" vector type iff the address space is global
+ // and the target supports 256-bit loads/stores
+ if (!CanLowerTo256Bit)
+ return std::nullopt;
+ LLVM_FALLTHROUGH;
case MVT::v2i8:
case MVT::v2i16:
case MVT::v2i32:
@@ -215,6 +232,15 @@ getVectorLoweringShape(EVT VectorEVT) {
case MVT::v4f32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);
+ case MVT::v16f16: // <8 x f16x2>
+ case MVT::v16bf16: // <8 x bf16x2>
+ case MVT::v16i16: // <8 x i16x2>
+ case MVT::v32i8: // <8 x i8x4>
+ // This can be upsized into a "native" vector type iff the address space is
+ // global and the target supports 256-bit loads/stores.
+ if (!CanLowerTo256Bit)
+ return std::nullopt;
+ LLVM_FALLTHROUGH;
case MVT::v8i8: // <2 x i8x4>
case MVT::v8f16: // <4 x f16x2>
case MVT::v8bf16: // <4 x bf16x2>
@@ -1070,10 +1096,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::ProxyReg)
MAKE_CASE(NVPTXISD::LoadV2)
MAKE_CASE(NVPTXISD::LoadV4)
+ MAKE_CASE(NVPTXISD::LoadV8)
MAKE_CASE(NVPTXISD::LDUV2)
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
MAKE_CASE(NVPTXISD::StoreV4)
+ MAKE_CASE(NVPTXISD::StoreV8)
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
MAKE_CASE(NVPTXISD::BFE)
@@ -3201,7 +3229,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
if (ValVT != MemVT)
return SDValue();
- const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
+ // 256-bit vectors are only allowed iff the address is global
+ // and the target supports 256-bit loads/stores
+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+ bool CanLowerTo256Bit =
+ AddrSpace == ADDRESS_SPACE_GLOBAL && STI.has256BitMaskedLoadStore();
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT, CanLowerTo256Bit);
if (!NumEltsAndEltVT)
return SDValue();
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3229,6 +3262,9 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case 4:
Opcode = NVPTXISD::StoreV4;
break;
+ case 8:
+ Opcode = NVPTXISD::StoreV8;
+ break;
}
SmallVector<SDValue, 8> Ops;
@@ -5765,7 +5801,8 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
- SmallVectorImpl<SDValue> &Results) {
+ SmallVectorImpl<SDValue> &Results,
+ bool TargetHas256BitVectorLoadStore) {
LoadSDNode *LD = cast<LoadSDNode>(N);
const EVT ResVT = LD->getValueType(0);
const EVT MemVT = LD->getMemoryVT();
@@ -5775,7 +5812,12 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
if (ResVT != MemVT)
return;
- const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+ // 256-bit vectors are only allowed iff the address is global
+ // and the target supports 256-bit loads/stores
+ unsigned AddrSpace = cast<MemSDNode>(N)->getAddressSpace();
+ bool CanLowerTo256Bit =
+ AddrSpace == ADDRESS_SPACE_GLOBAL && TargetHas256BitVectorLoadStore;
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT, CanLowerTo256Bit);
if (!NumEltsAndEltVT)
return;
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -5812,6 +5854,13 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
break;
}
+ case 8: {
+ Opcode = NVPTXISD::LoadV8;
+ EVT ListVTs[] = {LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT,
+ LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other};
+ LdResVTs = DAG.getVTList(ListVTs);
+ break;
+ }
}
SDLoc DL(LD);
@@ -6084,7 +6133,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
- ReplaceLoadVector(N, DAG, Results);
+ ReplaceLoadVector(N, DAG, Results, STI.has256BitMaskedLoadStore());
return;
case ISD::INTRINSIC_W_CHAIN:
ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 7a8bf3bf33a94..3dff83d74538b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -84,10 +84,12 @@ enum NodeType : unsigned {
FIRST_MEMORY_OPCODE,
LoadV2 = FIRST_MEMORY_OPCODE,
LoadV4,
+ LoadV8,
LDUV2, // LDU.v2
LDUV4, // LDU.v4
StoreV2,
StoreV4,
+ StoreV8,
LoadParam,
LoadParamV2,
LoadParamV4,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a384cb79d645a..d0f3fb4ec1c1d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2425,7 +2425,7 @@ let mayStore=1, hasSideEffects=0 in {
// The following is used only in and after vector elementizations. Vector
// elementization happens at the machine instruction level, so the following
// instructions never appear in the DAG.
-multiclass LD_VEC<NVPTXRegClass regclass> {
+multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
@@ -2438,17 +2438,27 @@ multiclass LD_VEC<NVPTXRegClass regclass> {
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
+ if support_v8 then {
+ def _v8 : NVPTXInst<
+ (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+ regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+ (ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+ i32imm:$fromWidth, ADDR:$addr),
+ "ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+ "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
+ "[$addr];", []>;
+ }
}
let mayLoad=1, hasSideEffects=0 in {
defm LDV_i8 : LD_VEC<Int16Regs>;
defm LDV_i16 : LD_VEC<Int16Regs>;
- defm LDV_i32 : LD_VEC<Int32Regs>;
+ defm LDV_i32 : LD_VEC<Int32Regs, true>;
defm LDV_i64 : LD_VEC<Int64Regs>;
- defm LDV_f32 : LD_VEC<Float32Regs>;
+ defm LDV_f32 : LD_VEC<Float32Regs, true>;
defm LDV_f64 : LD_VEC<Float64Regs>;
}
-multiclass ST_VEC<NVPTXRegClass regclass> {
+multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
@@ -2463,14 +2473,25 @@ multiclass ST_VEC<NVPTXRegClass regclass> {
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
+ if support_v8 then {
+ def _v8 : NVPTXInst<
+ (outs),
+ (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
+ regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
+ LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
+ i32imm:$fromWidth, ADDR:$addr),
+ "st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
+ "\t[$addr], "
+ "{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
+ }
}
let mayStore=1, hasSideEffects=0 in {
defm STV_i8 : ST_VEC<Int16Regs>;
defm STV_i16 : ST_VEC<Int16Regs>;
- defm STV_i32 : ST_VEC<Int32Regs>;
+ defm STV_i32 : ST_VEC<Int32Regs, true>;
defm STV_i64 : ST_VEC<Int64Regs>;
- defm STV_f32 : ST_VEC<Float32Regs>;
+ defm STV_f32 : ST_VEC<Float32Regs, true>;
defm STV_f64 : ST_VEC<Float64Regs>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 7b139d7b79e7d..cdbf29c140429 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2388,6 +2388,12 @@ class VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> :
(ins ADDR:$src),
"ld.global.nc.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
+class VLDG_G_ELE_V8<string TyStr, NVPTXRegClass regclass> :
+ NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
+ regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
+ (ins ADDR:$src),
+ "ld.global.nc.v8." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
+
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
def INT_PTX_LDG_G_v2i8_ELE : VLDG_G_ELE_V2<"u8", Int16Regs>;
def INT_PTX_LDG_G_v2i16_ELE : VLDG_G_ELE_V2<"u16", Int16Regs>;
@@ -2401,6 +2407,10 @@ def INT_PTX_LDG_G_v4i16_ELE : VLDG_G_ELE_V4<"u16", Int16Regs>;
def INT_PTX_LDG_G_v4i32_ELE : VLDG_G_ELE_V4<"u32", Int32Regs>;
def INT_PTX_LDG_G_v4f32_ELE : VLDG_G_ELE_V4<"f32", Float32Regs>;
+def INT_PTX_LDG_G_v4i64_ELE : VLDG_G_ELE_V4<"u64", Int64Regs>;
+def INT_PTX_LDG_G_v4f64_ELE : VLDG_G_ELE_V4<"f64", Float64Regs>;
+def INT_PTX_LDG_G_v8i32_ELE : VLDG_G_ELE_V8<"u32", Int32Regs>;
+def INT_PTX_LDG_G_v8f32_ELE : VLDG_G_ELE_V8<"f32", Float32Regs>;
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 0a4fc8d1435be..5552bba728160 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -72,6 +72,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
const SelectionDAGTargetInfo *getSelectionDAGInfo() const override;
+ bool has256BitMaskedLoadStore() const {
+ return SmVersion >= 100 && PTXVersion >= 88;
+ }
bool hasAtomAddF64() const { return SmVersion >= 60; }
bool hasAtomScope() const { return SmVersion >= 60; }
bool hasAtomBitwise64() const { return SmVersion >= 32; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 66c5139f8c2cc..1d8525fd4656f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -591,6 +591,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
+ // 256 bit loads/stores are currently only supported for global address space
+ if (AddrSpace == ADDRESS_SPACE_GLOBAL && ST->has256BitMaskedLoadStore())
+ return 256;
+ return 128;
+}
+
unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
if (isa<AllocaInst>(V))
return ADDRESS_SPACE_LOCAL;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index a9bd5a0d01043..98aea4e535f0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -173,6 +173,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
+
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
Value *NewV) const override;
unsigned getAssumedAddrSpace(const Value *V) const override;
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
new file mode 100644
index 0000000000000..f4abcb37aa894
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -0,0 +1,520 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 -verify-machineinstrs | FileCheck %s -check-prefixes=SM90
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx87 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 -verify-machineinstrs | FileCheck %s -check-prefixes=SM100
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; For 256-bit vectors, check that invariant loads from the
+; global addrspace are lowered to ld.globa...
[truncated]
``````````
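On the vectorizer side, here is a minimal sketch of what the new getLoadStoreVecRegBitWidth hook unlocks (identifiers are illustrative, not taken from the patch's tests). With the hook returning 256 for the global address space, the LoadStoreVectorizer can merge the two adjacent 128-bit loads below into a single 256-bit load, which llc then emits as one of the new v8 instructions:

```llvm
; opt -passes=load-store-vectorizer -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx88
define <4 x float> @sum_halves(ptr addrspace(1) %p) {
; CHECK: load <8 x float>, ptr addrspace(1) %p, align 32
  %lo = load <4 x float>, ptr addrspace(1) %p, align 32
  %gep = getelementptr inbounds float, ptr addrspace(1) %p, i64 4
  %hi = load <4 x float>, ptr addrspace(1) %gep, align 16
  %sum = fadd <4 x float> %lo, %hi
  ret <4 x float> %sum
}
```

For non-global address spaces the hook still returns 128, so nothing wider than v4 is formed; invariant global loads take the ld.global.nc.v8 path via the new INT_PTX_LDG_G_v8i32_ELE/INT_PTX_LDG_G_v8f32_ELE patterns.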
https://github.com/llvm/llvm-project/pull/139292