[llvm] 17852de - [NVPTX] Lower LLVM masked vector loads and stores to PTX (#159387)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 25 08:26:20 PST 2025
Author: Drew Kersnar
Date: 2025-11-25T10:26:15-06:00
New Revision: 17852deda7fb9dabb41023e2673025c630b9369d
URL: https://github.com/llvm/llvm-project/commit/17852deda7fb9dabb41023e2673025c630b9369d
DIFF: https://github.com/llvm/llvm-project/commit/17852deda7fb9dabb41023e2673025c630b9369d.diff
LOG: [NVPTX] Lower LLVM masked vector loads and stores to PTX (#159387)
This backend support will allow the LoadStoreVectorizer, in certain
cases, to fill in gaps when creating load/store vectors and generate
LLVM masked load/stores
(https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). To
accomplish this, changes are separated into two parts. This first part
has the backend lowering and TTI changes, and a follow-up PR will have
the LSV generate these intrinsics:
https://github.com/llvm/llvm-project/pull/159388.
In this backend change, masked loads are lowered to PTX loads preceded by a
`.pragma "used_bytes_mask <mask>";` directive
(https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask),
and masked stores are lowered to PTX stores using the new sink symbol (`_`) syntax
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st).
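For example, the added `masked-load-vectors.ll` test checks the following
sm_100 output for a v8i32 masked load/store pair whose mask enables elements
0, 2, and 7 (abridged; register numbering as in the test):
```
.pragma "used_bytes_mask 0xf0000f0f";
ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
...
st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
```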
# TTI Changes
TTI changes are needed because NVPTX only supports masked loads/stores
with _constant_ masks. `ScalarizeMaskedMemIntrin.cpp` is adjusted to
check that the mask is constant and pass that result into the TTI check.
Behavior shouldn't change for non-NVPTX targets, which do not care
whether the mask is variable or constant when determining legality, but
all TTI files that implement these APIs need to be updated.
# Masked store lowering implementation details
If the masked stores make it to the NVPTX backend without being
scalarized, they are handled by the following:
* `NVPTXISelLowering.cpp` - Sets up a custom operation action and
handles it in lowerMSTORE. The handling is similar to that of normal store
vectors, except that we read the mask and place the sentinel register
`$noreg` in each position where the mask is false.
For example,
```
t10: v8i1 = BUILD_VECTOR Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>, Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>
t11: ch = masked_store<(store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)> t5:1, t5, t7, undef:i64, t10
->
STV_i32_v8 killed %13:int32regs, $noreg, $noreg, killed %16:int32regs, killed %17:int32regs, $noreg, $noreg, killed %20:int32regs, 0, 0, 1, 8, 0, 32, %4:int64regs, 0, debug-location !18 :: (store unknown-size into %ir.lsr.iv28, align 32, addrspace 1);
```
* `NVPTXInstrInfo.td` - Changes the definition of store vectors to allow
for a mix of sink symbols and registers.
* `NVPTXInstPrinter.h/.cpp` - Handles the `$noreg` case by printing "_" (illustrated below).
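Putting these pieces together, the `STV_i32_v8` from the example above is
printed with "_" in each `$noreg` position, roughly as follows (a schematic
sketch; register names are illustrative, not taken from an actual test):
```
st.global.v8.b32 [%rd1], {%r1, _, _, %r2, %r3, _, _, %r4};
```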
# Masked load lowering implementation details
Masked loads are routed to normal PTX loads, with one difference: a
`#pragma "used_bytes_mask"` is emitted before the load instruction
(https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask).
To accomplish this, a new operand is added to every NVPTXISD Load type
representing this mask.
* `NVPTXISelLowering.h/.cpp` - Masked loads are converted into normal
NVPTXISD loads with a mask operand in two ways. 1) In type legalization
through replaceLoadVector, which is the normal path, and 2) through
LowerMLOAD, to handle the legal vector types
(v2f16/v2bf16/v2i16/v4i8/v2f32) that will not be type legalized. Both
share the same convertMLOADToLoadWithUsedBytesMask helper. Both default
this operand to UINT32_MAX, representing that all bytes are used. For the
latter, we need a new `NVPTXISD::MLoad` node type to represent that edge
case because we cannot put the used bytes mask operand on a generic
LoadSDNode.
* `NVPTXISelDAGToDAG.cpp` - Extract the used bytes mask from loads and add
it to the created machine instructions.
* `NVPTXInstPrinter.h/.cpp` - Print the pragma when the used bytes mask
isn't all ones.
* `NVPTXForwardParams.cpp`, `NVPTXReplaceImageHandles.cpp` - Update
manual indexing of load operands to account for the new operand.
* `NVPTXInstrInfo.td`, `NVPTXIntrinsics.td` - Add the used bytes mask to
the MI definitions.
* `NVPTXTagInvariantLoads.cpp` - Ensure that masked loads also get
tagged as invariant.
Some generic changes are also needed:
* `LegalizeVectorTypes.cpp` - Ensure memory operand flags are preserved when
splitting masked loads (see the sm_90 example after this list).
* `SelectionDAGBuilder.cpp` - Preserve `MD_invariant_load` when creating
masked load SDNodes.
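On targets without 256-bit vector loads (the sm_90 run lines in the added
test), the same v8i32 masked load is split into two v4 loads during type
legalization, and each half carries its own used-bytes mask:
```
.pragma "used_bytes_mask 0xf000";
ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
.pragma "used_bytes_mask 0xf0f";
ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
```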
Added:
llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
Modified:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
llvm/lib/Target/ARM/ARMTargetTransformInfo.h
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
llvm/lib/Target/VE/VETargetTransformInfo.h
llvm/lib/Target/X86/X86TargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.h
llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index bd4c901e9bc82..22cff2035eb0b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -842,12 +842,20 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
+ /// Some targets only support masked load/store with a constant mask.
+ enum MaskKind {
+ VariableOrConstantMask,
+ ConstantMask,
+ };
+
/// Return true if the target supports masked store.
- LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ LLVM_ABI bool
+ isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+ MaskKind MaskKind = VariableOrConstantMask) const;
/// Return true if the target supports masked load.
- LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ LLVM_ABI bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ MaskKind MaskKind = VariableOrConstantMask) const;
/// Return true if the target supports nontemporal store.
LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 580b219ddbe53..4954c0d90a1e1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,12 +309,14 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
return false;
}
virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 46f90b4cec7c9..9768202a9ba26 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,13 +468,17 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+ MaskKind);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
+ return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+ MaskKind);
}
bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 24a18e181ba80..4274e951446b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2465,6 +2465,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
SDValue PassThru = MLD->getPassThru();
Align Alignment = MLD->getBaseAlign();
ISD::LoadExtType ExtType = MLD->getExtensionType();
+ MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
// Split Mask operand
SDValue MaskLo, MaskHi;
@@ -2490,9 +2491,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
- MLD->getPointerInfo(), MachineMemOperand::MOLoad,
- LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
- MLD->getRanges());
+ MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
+ Alignment, MLD->getAAInfo(), MLD->getRanges());
Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
MMO, MLD->getAddressingMode(), ExtType,
@@ -2515,8 +2515,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
LoMemVT.getStoreSize().getFixedValue());
MMO = DAG.getMachineFunction().getMachineMemOperand(
- MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
- Alignment, MLD->getAAInfo(), MLD->getRanges());
+ MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
+ MLD->getAAInfo(), MLD->getRanges());
Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
HiMemVT, MMO, MLD->getAddressingMode(), ExtType,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 985a54ca83256..88b35582a9f7d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5063,6 +5063,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
auto MMOFlags = MachineMemOperand::MOLoad;
if (I.hasMetadata(LLVMContext::MD_nontemporal))
MMOFlags |= MachineMemOperand::MONonTemporal;
+ if (I.hasMetadata(LLVMContext::MD_invariant_load))
+ MMOFlags |= MachineMemOperand::MOInvariant;
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 6cc4987428567..52fc28a98449b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -323,12 +323,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index d12b802fe234f..fdb0ec40cb41f 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
}
bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const {
if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
return false;
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 919a6fc9fd0b0..30f2151b41239 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,12 +186,16 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
bool isProfitableLSRChainElement(Instruction *I) const override;
- bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override;
-
- bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override {
- return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+ bool
+ isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
+
+ bool
+ isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override {
+ return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
}
bool forceScalarizeMaskedGather(VectorType *VTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 8f3f0cc8abb01..3f84cbb6555ed 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -343,14 +343,16 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
}
bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e95b5a10b76a7..67388984bb3e3 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -165,9 +165,10 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const override;
+ bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const override;
bool isLegalMaskedGather(Type *Ty, Align Alignment) const override;
bool isLegalMaskedScatter(Type *Ty, Align Alignment) const override;
bool forceScalarizeMaskedGather(VectorType *VTy,
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 77913f27838e2..5ff5fa36ac467 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,25 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
+ raw_ostream &O) {
+ auto &Op = MI->getOperand(OpNum);
+ assert(Op.isImm() && "Invalid operand");
+ uint32_t Imm = (uint32_t)Op.getImm();
+ if (Imm != UINT32_MAX) {
+ O << ".pragma \"used_bytes_mask " << format_hex(Imm, 1) << "\";\n\t";
+ }
+}
+
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+ raw_ostream &O) {
+ const MCOperand &Op = MI->getOperand(OpNum);
+ if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+ O << "_";
+ else
+ printOperand(MI, OpNum, O);
+}
+
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..3d172441adfcc 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
+ void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index a3496090def3c..c8b53571c1e59 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
const MachineOperand *ParamSymbol = Mov.uses().begin();
assert(ParamSymbol->isSymbol());
- constexpr unsigned LDInstBasePtrOpIdx = 5;
+ constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
for (auto *LI : LoadInsts) {
(LI->uses().begin() + LDInstBasePtrOpIdx)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 996d653940118..0e1125ab8d8b3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
switch (N->getOpcode()) {
case ISD::LOAD:
case ISD::ATOMIC_LOAD:
+ case NVPTXISD::MLoad:
if (tryLoad(N))
return;
break;
@@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;
+ uint32_t UsedBytesMask;
+ switch (N->getOpcode()) {
+ case ISD::LOAD:
+ case ISD::ATOMIC_LOAD:
+ UsedBytesMask = UINT32_MAX;
+ break;
+ case NVPTXISD::MLoad:
+ UsedBytesMask = N->getConstantOperandVal(3);
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && "Invalid width for load");
@@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
@@ -1196,14 +1211,14 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
// Read at least 8 bits (predicates are stored as 8-bit values)
- // The last operand holds the original LoadSDNode::getExtensionType() value
- const unsigned ExtensionType =
- N->getConstantOperandVal(N->getNumOperands() - 1);
+ // Get the original LoadSDNode::getExtensionType() value
+ const unsigned ExtensionType = N->getConstantOperandVal(4);
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
+ const uint32_t UsedBytesMask = N->getConstantOperandVal(3);
assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
@@ -1213,6 +1228,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
@@ -1250,10 +1266,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
SDLoc DL(LD);
unsigned ExtensionType;
+ uint32_t UsedBytesMask;
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
ExtensionType = Load->getExtensionType();
+ UsedBytesMask = UINT32_MAX;
} else {
- ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+ ExtensionType = LD->getConstantOperandVal(4);
+ UsedBytesMask = LD->getConstantOperandVal(3);
}
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
@@ -1265,8 +1284,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
ExtensionType != ISD::NON_EXTLOAD));
const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
- SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
- Offset, LD->getChain()};
+ SDValue Ops[] = {getI32Imm(FromType, DL),
+ getI32Imm(FromTypeWidth, DL),
+ getI32Imm(UsedBytesMask, DL),
+ Base,
+ Offset,
+ LD->getChain()};
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
std::optional<unsigned> Opcode;
@@ -1277,6 +1300,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
+ case NVPTXISD::MLoad:
+ Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
+ NVPTX::LD_GLOBAL_NC_i64);
+ break;
case NVPTXISD::LoadV2:
Opcode =
pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1773958520f04..f3a3bc785d997 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -770,7 +770,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+ Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -3133,6 +3134,86 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (auto [I, Op] : enumerate(Mask->ops())) {
+ // Mask elements must be constants.
+ if (Op.getNode()->getAsZExtVal() == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a masked
+ // off element, this will be handled in tablegen
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -3169,8 +3250,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
+ case ISD::MLOAD:
+ return LowerMLOAD(Op, DAG);
case ISD::SHL_PARTS:
return LowerShiftLeftParts(Op, DAG);
case ISD::SRA_PARTS:
@@ -3363,10 +3452,56 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
MachinePointerInfo(SV));
}
+static std::pair<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+ SDValue Chain = N->getOperand(0);
+ SDValue BasePtr = N->getOperand(1);
+ SDValue Mask = N->getOperand(3);
+ SDValue Passthru = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ResVT = N->getValueType(0);
+ assert(ResVT.isVector() && "Masked vector load must have vector type");
+ // While we only expect poison passthru vectors as an input to the backend,
+ // when the legalization framework splits a poison vector in half, it creates
+ // two undef vectors, so we can technically expect those too.
+ assert((Passthru.getOpcode() == ISD::POISON ||
+ Passthru.getOpcode() == ISD::UNDEF) &&
+ "Passthru operand expected to be poison or undef");
+
+ // Extract the mask and convert it to a uint32_t representing the used bytes
+ // of the entire vector load
+ uint32_t UsedBytesMask = 0;
+ uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+ assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+ uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+ uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+ for (SDValue Op : reverse(Mask->ops())) {
+ // We technically only want to do this shift for every
+ // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
+ // so this shift is a no-op.
+ UsedBytesMask <<= ElementSizeInBytes;
+
+ // Mask elements must be constants.
+ if (Op->getAsZExtVal() != 0)
+ UsedBytesMask |= ElementMask;
+ }
+
+ assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+ "Unexpected masked load with elements masked all on or all off");
+
+ // Create a new load sd node to be handled normally by ReplaceLoadVector.
+ MemSDNode *NewLD = cast<MemSDNode>(
+ DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+ return {NewLD, UsedBytesMask};
+}
+
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
- LoadSDNode *LD = cast<LoadSDNode>(N);
+ MemSDNode *LD = cast<MemSDNode>(N);
const EVT ResVT = LD->getValueType(0);
const EVT MemVT = LD->getMemoryVT();
@@ -3393,6 +3528,11 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
return std::nullopt;
}
+ // If we have a masked load, convert it to a normal load now
+ std::optional<uint32_t> UsedBytesMask = std::nullopt;
+ if (LD->getOpcode() == ISD::MLOAD)
+ std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
// loaded type to i16 and propagate the "real" type as the memory type.
@@ -3421,9 +3561,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
// Copy regular operands
SmallVector<SDValue, 8> OtherOps(LD->ops());
+ OtherOps.push_back(
+ DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
// The select routine does not have access to the LoadSDNode instance, so
// pass along the extension information
- OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
LD->getMemOperand());
@@ -3511,6 +3655,42 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
llvm_unreachable("Unexpected custom lowering for load");
}
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+ // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
+ // masked loads of these types and have to handle them here.
+ // v2f32 also needs to be handled here if the subtarget has f32x2
+ // instructions, making it legal.
+ //
+ // Note: misaligned masked loads should never reach this point
+ // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+ // will validate alignment. Therefore, we do not need to special case handle
+ // them here.
+ EVT VT = Op.getValueType();
+ if (NVPTX::isPackedVectorTy(VT)) {
+ auto Result =
+ convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+ MemSDNode *LD = std::get<0>(Result);
+ uint32_t UsedBytesMask = std::get<1>(Result);
+
+ SDLoc DL(LD);
+
+ // Copy regular operands
+ SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+ OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+ // We currently are not lowering extending loads, but pass the extension
+ // type anyway as later handling expects it.
+ OtherOps.push_back(
+ DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ SDValue NewLD =
+ DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+ LD->getMemoryVT(), LD->getMemOperand());
+ return NewLD;
+ }
+ return SDValue();
+}
+
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
const NVPTXSubtarget &STI) {
MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5419,6 +5599,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
// ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
// here.
Opcode = NVPTXISD::LoadV2;
+ // append a "full" used bytes mask operand right before the extension type
+ // operand, signifying that all bytes are used.
+ Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
@@ -5427,9 +5610,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Opcode = NVPTXISD::LoadV4;
break;
case NVPTXISD::LoadV4:
- // V8 is only supported for f32. Don't forget, we're not changing the load
- // size here. This is already a 256-bit load.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // load size here. This is already a 256-bit load.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
OldNumOutputs = 4;
Opcode = NVPTXISD::LoadV8;
@@ -5504,9 +5687,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
- // V8 is only supported for f32. Don't forget, we're not changing the store
- // size here. This is already a 256-bit store.
- if (ElementVT != MVT::v2f32)
+ // V8 is only supported for f32/i32. Don't forget, we're not changing the
+ // store size here. This is already a 256-bit store.
+ if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
return SDValue();
Opcode = NVPTXISD::StoreV8;
break;
@@ -6657,6 +6840,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
+ case ISD::MLOAD:
replaceLoadVector(N, DAG, Results, STI);
return;
case ISD::INTRINSIC_W_CHAIN:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index d71a86fd463f6..dd8e49de7aa6a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -235,6 +235,7 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 04e2dd435cdf0..feefaf9a21e5b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1588,6 +1588,14 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def UsedBytesMask : Operand<i32> {
+ let PrintMethod = "printUsedBytesMaskPragma";
+}
+
+def RegOrSink : Operand<Any> {
+ let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
def AtomicCode : Operand<i32> {
let PrintMethod = "printAtomicCode";
}
@@ -1832,8 +1840,10 @@ def Callseq_End :
class LD<NVPTXRegClass regclass>
: NVPTXInst<
(outs regclass:$dst),
- (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
- i32imm:$fromWidth, ADDR:$addr),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$fromWidth "
"\t$dst, [$addr];">;
@@ -1865,21 +1875,27 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
def _v2 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
- AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2}}, [$addr];">;
def _v4 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
(ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
- AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];">;
if support_v8 then
def _v8 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
- (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
- i32imm:$fromWidth, ADDR:$addr),
+ (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+ AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$addr),
+ "${usedBytes}"
"ld${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
"[$addr];">;
@@ -1900,7 +1916,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
"\t[$addr], {{$src1, $src2}};">;
def _v4 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1908,8 +1924,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
if support_v8 then
def _v8 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
- O:$src5, O:$src6, O:$src7, O:$src8,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+ RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8501d4d7bb86f..d18c7e20df038 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2552,7 +2552,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
// during the lifetime of the kernel.
class LDG_G<NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ : NVPTXInst<(outs regclass:$result),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth,
+ UsedBytesMask:$usedBytes, ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2564,19 +2567,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
// Elementized vector ldg
class VLDG_G_ELE_V2<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
class VLDG_G_ELE_V4<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index 320c0fb6950a7..4bbf49f93f43b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -1808,8 +1808,8 @@ bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
// For CUDA, we preserve the param loads coming from function arguments
return false;
- assert(TexHandleDef.getOperand(6).isSymbol() && "Load is not a symbol!");
- StringRef Sym = TexHandleDef.getOperand(6).getSymbolName();
+ assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
+ StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
InstrsToRemove.insert(&TexHandleDef);
Op.ChangeToES(Sym.data());
MFI->getImageHandleSymbolIndex(Sym);
diff --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
index e8ea1ad6c404d..710d063e75725 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
@@ -30,6 +30,7 @@ const char *NVPTXSelectionDAGInfo::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::LoadV2)
MAKE_CASE(NVPTXISD::LoadV4)
MAKE_CASE(NVPTXISD::LoadV8)
+ MAKE_CASE(NVPTXISD::MLoad)
MAKE_CASE(NVPTXISD::LDUV2)
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
diff --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
index 07c130baeaa4f..9dd0a1eaa5856 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
@@ -36,6 +36,7 @@ enum NodeType : unsigned {
LoadV2,
LoadV4,
LoadV8,
+ MLoad,
LDUV2, // LDU.v2
LDUV4, // LDU.v4
StoreV2,
diff --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
index a4aff44ac04f6..f1774a7c5572e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
@@ -27,13 +27,14 @@
using namespace llvm;
-static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
+static bool isInvariantLoad(const Instruction *I, const Value *Ptr,
+ const bool IsKernelFn) {
// Don't bother with non-global loads
- if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
+ if (Ptr->getType()->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
return false;
// If the load is already marked as invariant, we don't need to do anything
- if (LI->getMetadata(LLVMContext::MD_invariant_load))
+ if (I->getMetadata(LLVMContext::MD_invariant_load))
return false;
// We use getUnderlyingObjects() here instead of getUnderlyingObject()
@@ -41,7 +42,7 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
// not. We need to look through phi nodes to handle pointer induction
// variables.
SmallVector<const Value *, 8> Objs;
- getUnderlyingObjects(LI->getPointerOperand(), Objs);
+ getUnderlyingObjects(Ptr, Objs);
return all_of(Objs, [&](const Value *V) {
if (const auto *A = dyn_cast<const Argument>(V))
@@ -53,9 +54,9 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
});
}
-static void markLoadsAsInvariant(LoadInst *LI) {
- LI->setMetadata(LLVMContext::MD_invariant_load,
- MDNode::get(LI->getContext(), {}));
+static void markLoadsAsInvariant(Instruction *I) {
+ I->setMetadata(LLVMContext::MD_invariant_load,
+ MDNode::get(I->getContext(), {}));
}
static bool tagInvariantLoads(Function &F) {
@@ -63,12 +64,17 @@ static bool tagInvariantLoads(Function &F) {
bool Changed = false;
for (auto &I : instructions(F)) {
- if (auto *LI = dyn_cast<LoadInst>(&I)) {
- if (isInvariantLoad(LI, IsKernelFn)) {
+ if (auto *LI = dyn_cast<LoadInst>(&I))
+ if (isInvariantLoad(LI, LI->getPointerOperand(), IsKernelFn)) {
markLoadsAsInvariant(LI);
Changed = true;
}
- }
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ if (II->getIntrinsicID() == Intrinsic::masked_load &&
+ isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
+ markLoadsAsInvariant(II);
+ Changed = true;
+ }
}
return Changed;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 64593e6439184..5d5553c573b0f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -592,6 +592,45 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+ unsigned AddrSpace,
+ TTI::MaskKind MaskKind) const {
+ if (MaskKind != TTI::MaskKind::ConstantMask)
+ return false;
+
+ // We currently only support this feature for 256-bit vectors, so the
+ // alignment must be at least 32
+ if (Alignment < 32)
+ return false;
+
+ if (!ST->has256BitVectorLoadStore(AddrSpace))
+ return false;
+
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+
+ auto *ElemTy = VTy->getScalarType();
+ return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
+}
+
+bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+ unsigned /*AddrSpace*/,
+ TTI::MaskKind MaskKind) const {
+ if (MaskKind != TTI::MaskKind::ConstantMask)
+ return false;
+
+ if (Alignment < DL.getTypeStoreSize(DataTy))
+ return false;
+
+ // We do not support sub-byte element type masked loads.
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+ return VTy->getElementType()->getScalarSizeInBits() >= 8;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 78eb751cf3c2e..d7f4e1da4073b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,12 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+ TTI::MaskKind MaskKind) const override;
+
+ bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
+ TTI::MaskKind MaskKind) const override;
+
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 39c1173e2986c..484c4791390ac 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -285,11 +285,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
}
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/,
+ TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..eed3832c9f1fb 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -134,12 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
}
// Load & Store {
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+ TargetTransformInfo::MaskKind /*MaskKind*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ bool isLegalMaskedStore(
+ Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+ TargetTransformInfo::MaskKind /*MaskKind*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 6ae55f3623096..a9dc96b53d530 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6322,7 +6322,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
}
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
@@ -6335,7 +6336,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace,
+ TTI::MaskKind MaskKind) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index f8866472bd9af..d6dea9427990b 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -267,10 +267,14 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
const TargetTransformInfo::LSRCost &C2) const override;
bool canMacroFuseCmp() const override;
- bool isLegalMaskedLoad(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
- bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ bool
+ isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
+ bool
+ isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+ TTI::MaskKind MaskKind =
+ TTI::MaskKind::VariableOrConstantMask) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 146e7d1047dd0..b7b08ae61ec52 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
if (TTI.isLegalMaskedLoad(
CI->getType(), CI->getParamAlign(0).valueOrOne(),
cast<PointerType>(CI->getArgOperand(0)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(1))
+ ? TTI::MaskKind::ConstantMask
+ : TTI::MaskKind::VariableOrConstantMask))
return false;
scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
@@ -1132,7 +1135,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getArgOperand(0)->getType(),
CI->getParamAlign(1).valueOrOne(),
cast<PointerType>(CI->getArgOperand(1)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(2))
+ ? TTI::MaskKind::ConstantMask
+ : TTI::MaskKind::VariableOrConstantMask))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
diff --git a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
index e3b072549bc04..3158916a3195c 100644
--- a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
+++ b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
@@ -40,9 +40,9 @@ registers:
- { id: 7, class: b32 }
body: |
bb.0.entry:
- %0 = LD_i32 0, 0, 4, 2, 32, &test_param_0, 0
+ %0 = LD_i32 0, 0, 4, 2, 32, -1, &test_param_0, 0
%1 = CVT_f64_f32 %0, 0
- %2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0
+ %2 = LD_i32 0, 0, 4, 0, 32, -1, &test_param_1, 0
; CHECK: %3:b64 = FADD_rnf64ri %1, double 3.250000e+00
%3 = FADD_rnf64ri %1, double 3.250000e+00
%4 = CVT_f32_f64 %3, 5
@@ -66,9 +66,9 @@ registers:
- { id: 7, class: b32 }
body: |
bb.0.entry:
- %0 = LD_i32 0, 0, 4, 2, 32, &test2_param_0, 0
+ %0 = LD_i32 0, 0, 4, 2, 32, -1, &test2_param_0, 0
%1 = CVT_f64_f32 %0, 0
- %2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0
+ %2 = LD_i32 0, 0, 4, 0, 32, -1, &test2_param_1, 0
; CHECK: %3:b64 = FADD_rnf64ri %1, double 0x7FF8000000000000
%3 = FADD_rnf64ri %1, double 0x7FF8000000000000
%4 = CVT_f32_f64 %3, 5
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
index 3fac29f74125b..d219493d2b31b 100644
--- a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -346,19 +346,15 @@ define i32 @ld_global_v8i32(ptr addrspace(1) %ptr) {
; SM100-LABEL: ld_global_v8i32(
; SM100: {
; SM100-NEXT: .reg .b32 %r<16>;
-; SM100-NEXT: .reg .b64 %rd<6>;
+; SM100-NEXT: .reg .b64 %rd<2>;
; SM100-EMPTY:
; SM100-NEXT: // %bb.0:
; SM100-NEXT: ld.param.b64 %rd1, [ld_global_v8i32_param_0];
-; SM100-NEXT: ld.global.nc.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
-; SM100-NEXT: mov.b64 {%r1, %r2}, %rd5;
-; SM100-NEXT: mov.b64 {%r3, %r4}, %rd4;
-; SM100-NEXT: mov.b64 {%r5, %r6}, %rd3;
-; SM100-NEXT: mov.b64 {%r7, %r8}, %rd2;
-; SM100-NEXT: add.s32 %r9, %r7, %r8;
-; SM100-NEXT: add.s32 %r10, %r5, %r6;
-; SM100-NEXT: add.s32 %r11, %r3, %r4;
-; SM100-NEXT: add.s32 %r12, %r1, %r2;
+; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: add.s32 %r9, %r1, %r2;
+; SM100-NEXT: add.s32 %r10, %r3, %r4;
+; SM100-NEXT: add.s32 %r11, %r5, %r6;
+; SM100-NEXT: add.s32 %r12, %r7, %r8;
; SM100-NEXT: add.s32 %r13, %r9, %r10;
; SM100-NEXT: add.s32 %r14, %r11, %r12;
; SM100-NEXT: add.s32 %r15, %r13, %r14;
diff --git a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
index 0b2d85600a2ef..4be91dfc60c6a 100644
--- a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
+++ b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
@@ -26,10 +26,10 @@ body: |
; CHECK: bb.0.entry:
; CHECK-NEXT: successors: %bb.2(0x30000000), %bb.3(0x50000000)
; CHECK-NEXT: {{ $}}
- ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
- ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+ ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+ ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
; CHECK-NEXT: [[ADD64ri:%[0-9]+]]:b64 = nuw ADD64ri killed [[LD_i64_]], 2
- ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, [[ADD64ri]], 0
+ ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, -1, [[ADD64ri]], 0
; CHECK-NEXT: [[SETP_i32ri:%[0-9]+]]:b1 = SETP_i32ri [[LD_i32_]], 0, 0
; CHECK-NEXT: CBranch killed [[SETP_i32ri]], %bb.2
; CHECK-NEXT: {{ $}}
@@ -54,10 +54,10 @@ body: |
bb.0.entry:
successors: %bb.2(0x30000000), %bb.1(0x50000000)
- %5:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
- %6:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+ %5:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+ %6:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
%0:b64 = nuw ADD64ri killed %6, 2
- %1:b32 = LD_i32 0, 0, 1, 3, 32, %0, 0
+ %1:b32 = LD_i32 0, 0, 1, 3, 32, -1, %0, 0
%7:b1 = SETP_i32ri %5, 0, 0
CBranch killed %7, %bb.2
GOTO %bb.1
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
new file mode 100644
index 0000000000000..8617dea310d6c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -0,0 +1,366 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+
+; Different architectures are tested in this file for the following reasons:
+; - SM90 does not have 256-bit load/store instructions
+; - SM90 does not have masked store instructions
+; - SM90 does not support packed f32x2 instructions
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+; Masked stores are only supported for 32-bit element types,
+; while masked loads are supported for all element types.
+define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_16xi16(
+; SM90: {
+; SM90-NEXT: .reg .b16 %rs<7>;
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r4;
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r7;
+; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r5;
+; SM90-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM90-NEXT: st.global.b16 [%rd2], %rs5;
+; SM90-NEXT: st.global.b16 [%rd2+2], %rs6;
+; SM90-NEXT: st.global.b16 [%rd2+8], %rs3;
+; SM90-NEXT: st.global.b16 [%rd2+10], %rs4;
+; SM90-NEXT: st.global.b16 [%rd2+28], %rs1;
+; SM90-NEXT: st.global.b16 [%rd2+30], %rs2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_16xi16(
+; SM100: {
+; SM100-NEXT: .reg .b16 %rs<7>;
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: mov.b32 {%rs1, %rs2}, %r8;
+; SM100-NEXT: mov.b32 {%rs3, %rs4}, %r3;
+; SM100-NEXT: mov.b32 {%rs5, %rs6}, %r1;
+; SM100-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM100-NEXT: st.global.b16 [%rd2], %rs5;
+; SM100-NEXT: st.global.b16 [%rd2+2], %rs6;
+; SM100-NEXT: st.global.b16 [%rd2+8], %rs3;
+; SM100-NEXT: st.global.b16 [%rd2+10], %rs4;
+; SM100-NEXT: st.global.b16 [%rd2+28], %rs1;
+; SM100-NEXT: st.global.b16 [%rd2+30], %rs2;
+; SM100-NEXT: ret;
+ %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) align 32 %a, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+ tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) align 32 %b, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+ ret void
+}
+
+define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<4>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_8xi32_no_align_param_0];
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_8xi32_no_align_param_1];
+; CHECK-NEXT: ld.global.b32 %r2, [%rd1+8];
+; CHECK-NEXT: ld.global.b32 %r3, [%rd1+28];
+; CHECK-NEXT: st.global.b32 [%rd2], %r1;
+; CHECK-NEXT: st.global.b32 [%rd2+8], %r2;
+; CHECK-NEXT: st.global.b32 [%rd2+28], %r3;
+; CHECK-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 16 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 16 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+
+define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_invariant(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT: ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT: ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_invariant(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x3";
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_param_1];
+; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_invariant(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_invariant_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x3";
+; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_invariant_param_1];
+; CHECK-NEXT: mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_no_align_param_0];
+; CHECK-NEXT: ld.global.b16 %rs1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
+; CHECK-NEXT: st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 2 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+ tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_invariant(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_no_align_param_0];
+; CHECK-NEXT: ld.global.b8 %rs1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_no_align_param_1];
+; CHECK-NEXT: ld.global.b8 %rs2, [%rd1+2];
+; CHECK-NEXT: st.global.b8 [%rd2], %rs1;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %rs2;
+; CHECK-NEXT: ret;
+ %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 2 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+ tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; In sm100+, we pack 2xf32 loads into a single b64 load while lowering
+define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
+; SM90-NEXT: ld.global.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
+; SM100-NEXT: ld.global.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32_invariant(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
+; SM90-NEXT: ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32_invariant(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
+; SM100-NEXT: ld.global.nc.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xf32_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xf32_no_align_param_0];
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
+; CHECK-NEXT: st.global.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+ %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+ tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+ ret void
+}
+
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), <2 x i1>)
+!0 = !{}
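
For readers cross-checking the `.pragma "used_bytes_mask ..."` values in the checks above: the hex constant encodes one bit per byte of the vector access (LSB = lowest address), so a constant element mask maps onto it by setting the bytes covered by each enabled lane. The snippet below is only an illustrative sketch under that assumption; `usedBytesMask` is a name invented here, not an LLVM or NVPTX API. It reproduces the 0xf0000f0f value checked in global_8xi32.

```
#include <cstdint>
#include <cstdio>
#include <vector>

// Hedged sketch (not the in-tree implementation): derive a
// "used_bytes_mask" value from a constant element mask, assuming one
// mask bit per byte of the vector access, LSB = lowest address.
static uint32_t usedBytesMask(const std::vector<bool> &ElemMask,
                              unsigned ElemSizeInBytes) {
  uint32_t Mask = 0;
  for (unsigned I = 0; I < ElemMask.size(); ++I)
    if (ElemMask[I])
      for (unsigned B = 0; B < ElemSizeInBytes; ++B)
        Mask |= 1u << (I * ElemSizeInBytes + B);
  return Mask;
}

int main() {
  // <8 x i32> mask <1,0,1,0,0,0,0,1> -> prints 0xf0000f0f,
  // matching the SM100 check line in global_8xi32 above.
  std::vector<bool> M = {1, 0, 1, 0, 0, 0, 0, 1};
  std::printf("0x%x\n", usedBytesMask(M, 4));
}
```

The SM90 checks are consistent with the same rule applied per split access: the 32-byte load is broken into two 16-byte v4.b32 loads, and each gets the 16-bit slice of the byte mask for its half (0xf0f for bytes 0-15, 0xf000 for bytes 16-31).
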
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..9f23acaf93bc8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT: and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT: setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT: ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT: and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT: setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT: ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT: ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT: ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT: not.pred %p5, %p1;
+; CHECK-NEXT: @%p5 bra $L__BB0_2;
+; CHECK-NEXT: // %bb.1: // %cond.store
+; CHECK-NEXT: st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT: $L__BB0_2: // %else
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT: not.pred %p6, %p2;
+; CHECK-NEXT: @%p6 bra $L__BB0_4;
+; CHECK-NEXT: // %bb.3: // %cond.store1
+; CHECK-NEXT: st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT: $L__BB0_4: // %else2
+; CHECK-NEXT: setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT: not.pred %p7, %p3;
+; CHECK-NEXT: @%p7 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.5: // %cond.store3
+; CHECK-NEXT: st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT: $L__BB0_6: // %else4
+; CHECK-NEXT: not.pred %p8, %p4;
+; CHECK-NEXT: @%p8 bra $L__BB0_8;
+; CHECK-NEXT: // %bb.7: // %cond.store5
+; CHECK-NEXT: st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT: $L__BB0_8: // %else6
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> %mask)
+ ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..feb7b7e0a0b39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,318 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll
+; and covers the lowering of 256-bit masked vector stores.
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will be legalized via scalarization before reaching the backend;
+;   this file tests that path even though the LSV will not generate them.
+
+; 256-bit vector loads/stores are only legal for blackwell+, so on sm_90, the vectors will be split
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr %a
+ tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x i64>, ptr %a
+ tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+ %a.load = load <8 x float>, ptr %a
+ tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+ %a.load = load <4 x double>, ptr %a
+ tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x i64>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+ %a.load = load <8 x float>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+ ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+ %a.load = load <4 x double>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+ ret void
+}
+
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT: st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT: st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
+; SM100-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+ ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK: {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ret;
+ %a.load = load <8 x i32>, ptr addrspace(1) %a
+ tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+; This is an example pattern for the LSV's output of these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT: st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT: st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT: ret;
+ %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+ %load05 = extractelement <8 x i32> %1, i32 0
+ %load16 = extractelement <8 x i32> %1, i32 1
+ %load38 = extractelement <8 x i32> %1, i32 3
+ %load49 = extractelement <8 x i32> %1, i32 4
+ %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+ %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+ %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+ %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+ %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+ %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+ %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+ %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+ call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) align 32 %out, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+ ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), <4 x i1>)
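
The sink-symbol stores in the SM100 checks above follow a simple pattern: one register operand per lane whose mask bit is set, and "_" for every masked-off lane. The sketch below is a hypothetical illustration, not code from this change; `storeOperands` is an invented name, and its %rN numbering simply follows lane order, which happens to line up with the autogenerated checks.

```
#include <cstdio>
#include <string>
#include <vector>

// Hedged sketch (not NVPTXInstPrinter itself): render the operand list
// of a masked st.global.v8.b32, printing "_" for lanes whose mask bit
// is false, mirroring the SM100 check lines above.
static std::string storeOperands(const std::vector<bool> &ElemMask) {
  std::string Out = "{";
  for (unsigned I = 0; I < ElemMask.size(); ++I) {
    if (I)
      Out += ", ";
    Out += ElemMask[I] ? "%r" + std::to_string(I + 1) : "_";
  }
  return Out + "}";
}

int main() {
  // Mask <1,1,0,1,1,0,0,0> from vectorizerOutput:
  // prints "{%r1, %r2, _, %r4, %r5, _, _, _}".
  std::printf("%s\n", storeOperands({1, 1, 0, 1, 1, 0, 0, 0}).c_str());
}
```
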
diff --git a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
index dfc84177fb0e6..a84b7fcd33836 100644
--- a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
+++ b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
@@ -77,7 +77,7 @@ constants: []
machineFunctionInfo: {}
body: |
bb.0:
- %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, &retval0, 0 :: (load (s128), addrspace 101)
+ %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s128), addrspace 101)
; CHECK-NOT: ProxyReg
%4:b32 = ProxyRegB32 killed %0
%5:b32 = ProxyRegB32 killed %1
@@ -86,7 +86,7 @@ body: |
; CHECK: STV_i32_v4 killed %0, killed %1, killed %2, killed %3
STV_i32_v4 killed %4, killed %5, killed %6, killed %7, 0, 0, 101, 32, &func_retval0, 0 :: (store (s128), addrspace 101)
- %8:b32 = LD_i32 0, 0, 101, 3, 32, &retval0, 0 :: (load (s32), addrspace 101)
+ %8:b32 = LD_i32 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s32), addrspace 101)
; CHECK-NOT: ProxyReg
%9:b32 = ProxyRegB32 killed %8
%10:b32 = ProxyRegB32 killed %9