[llvm] 17852de - [NVPTX] Lower LLVM masked vector loads and stores to PTX (#159387)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 25 08:26:20 PST 2025


Author: Drew Kersnar
Date: 2025-11-25T10:26:15-06:00
New Revision: 17852deda7fb9dabb41023e2673025c630b9369d

URL: https://github.com/llvm/llvm-project/commit/17852deda7fb9dabb41023e2673025c630b9369d
DIFF: https://github.com/llvm/llvm-project/commit/17852deda7fb9dabb41023e2673025c630b9369d.diff

LOG: [NVPTX] Lower LLVM masked vector loads and stores to PTX (#159387)

This backend support will allow the LoadStoreVectorizer, in certain
cases, to fill in gaps when creating load/store vectors and generate
LLVM masked loads/stores
(https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). To
accomplish this, changes are separated into two parts. This first part
has the backend lowering and TTI changes, and a follow-up PR will have
the LSV generate these intrinsics:
https://github.com/llvm/llvm-project/pull/159388.

In this backend change, masked loads are lowered to PTX with
`.pragma "used_bytes_mask [mask]";`
(https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask),
and masked stores are lowered to PTX using the new sink symbol syntax
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st).

# TTI Changes
TTI changes are needed because NVPTX only supports masked loads/stores
with _constant_ masks. `ScalarizeMaskedMemIntrin.cpp` is adjusted to
check that the mask is constant and pass that result into the TTI check.
Behavior shouldn't change for non-NVPTX targets, which do not care
whether the mask is variable or constant when determining legality, but
all TTI files that implement these APIs need to be updated.
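
As a rough illustration of the legality query described above, here is a
minimal, self-contained C++ sketch. It does not use the LLVM headers:
`MaskKind` mirrors the enum added to `TargetTransformInfo.h`, while
`classifyMask`, `constantMaskOnlyTarget`, and `anyMaskTarget` are
hypothetical names invented for this example. The real hooks also take the
data type, alignment, and address space.

```cpp
#include <cassert>

// Stand-in for the TTI::MaskKind enum added by this patch.
enum class MaskKind { VariableOrConstantMask, ConstantMask };

// Hypothetical caller-side helper: the scalarization pass checks whether the
// mask operand is a constant and forwards that fact to the legality query.
MaskKind classifyMask(bool MaskIsConstant) {
  return MaskIsConstant ? MaskKind::ConstantMask
                        : MaskKind::VariableOrConstantMask;
}

// Hypothetical target hook for a target that, like NVPTX, can only lower
// masked memory operations whose mask is known at compile time.
bool constantMaskOnlyTarget(MaskKind Kind) {
  return Kind == MaskKind::ConstantMask;
}

// Hypothetical target hook for a target that ignores the mask kind, which is
// how the non-NVPTX targets in this patch behave.
bool anyMaskTarget(MaskKind /*Kind*/) { return true; }
```

Under these assumptions, `constantMaskOnlyTarget(classifyMask(false))` is
false, so a variable-mask operation would still be scalarized, while
`anyMaskTarget` returns the same answer for either kind.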

# Masked store lowering implementation details
If the masked stores make it to the NVPTX backend without being
scalarized, they are handled by the following:
* `NVPTXISelLowering.cpp` - Sets up a custom operation action and
handles it in lowerMSTORE. Similar handling to normal store vectors,
except we read the mask and place a sentinel register `$noreg` in each
position where the mask reads as false.

For example, 
```
t10: v8i1 = BUILD_VECTOR Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>, Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>
t11: ch = masked_store<(store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)> t5:1, t5, t7, undef:i64, t10

->

STV_i32_v8 killed %13:int32regs, $noreg, $noreg, killed %16:int32regs, killed %17:int32regs, $noreg, $noreg, killed %20:int32regs, 0, 0, 1, 8, 0, 32, %4:int64regs, 0, debug-location !18 :: (store unknown-size into %ir.lsr.iv28, align 32, addrspace 1);

```

* `NVPTXInstrInfo.td` - Changes the definition of store vectors to allow
for a mix of sink symbols and registers.
* `NVPTXInstPrinter.h/.cpp` - Handles the `$noreg` case by printing "_".
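
The operand handling above can be sketched in standalone C++. This is only
an illustration of the idea behind `printRegisterOrSinkSymbol`, not the
backend code: `formatStoreOperands` and the register names are hypothetical,
and the mask is modelled as a plain `std::vector<bool>` rather than a
`BUILD_VECTOR` of constants.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: build the operand list of a vector store, replacing
// every masked-off element with the PTX sink symbol "_".
std::string formatStoreOperands(const std::vector<std::string> &Regs,
                                const std::vector<bool> &Mask) {
  assert(Regs.size() == Mask.size() && "expected one mask bit per element");
  std::string Out = "{";
  for (std::size_t I = 0; I < Regs.size(); ++I) {
    if (I != 0)
      Out += ", ";
    // A false mask bit plays the role of the `$noreg` sentinel operand.
    Out += Mask[I] ? Regs[I] : std::string("_");
  }
  Out += "}";
  return Out;
}
```

For instance, `formatStoreOperands({"%r1", "%r2", "%r3", "%r4"}, {true, false, false, true})`
returns `{%r1, _, _, %r4}`.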

# Masked load lowering implementation details
Masked loads are routed to normal PTX loads, with one difference: a
`.pragma "used_bytes_mask [mask]";` directive is emitted before the load
instruction
(https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask).
To accomplish this, a new operand is added to every NVPTXISD Load type
representing this mask.
* `NVPTXISelLowering.h/.cpp` - Masked loads are converted into normal
NVPTXISD loads with a mask operand in two ways. 1) In type legalization
through replaceLoadVector, which is the normal path, and 2) through
LowerMLOAD, to handle the legal vector types
(v2f16/v2bf16/v2i16/v4i8/v2f32) that will not be type legalized. Both
share the same convertMLOADToLoadWithUsedBytesMask helper. Both default
this operand to UINT32_MAX, representing all bytes on. For the latter,
we need a new `NVPTXISD::MLoad` node type to represent that edge case
because we cannot put the used bytes mask operand on a generic
LoadSDNode.
* `NVPTXISelDAGToDAG.cpp` - Extracts the used bytes mask from loads and
adds it to the created machine instructions.
* `NVPTXInstPrinter.h/.cpp` - Print the pragma when the used bytes mask
isn't all ones.
* `NVPTXForwardParams.cpp`, `NVPTXReplaceImageHandles.cpp` - Update
manual indexing of load operands to account for new operand.
* `NVPTXInstrInfo.td`, `NVPTXIntrinsics.td` - Add the used bytes mask to
the MI definitions.
* `NVPTXTagInvariantLoads.cpp` - Ensure that masked loads also get
tagged as invariant.
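
The per-element to per-byte mask conversion can be sketched outside the
SelectionDAG. The function below follows the same loop shape as
`convertMLOADToLoadWithUsedBytesMask` in this patch, but it is a standalone
approximation: `computeUsedBytesMask` is a hypothetical name, and the mask is
a plain `std::vector<bool>` instead of constant `SDValue` operands.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: expand a per-element mask into a per-byte mask in
// which element 0 occupies the least significant bits.
uint32_t computeUsedBytesMask(const std::vector<bool> &Mask,
                              uint32_t ElementSizeInBytes) {
  assert(ElementSizeInBytes > 0 && ElementSizeInBytes < 32 &&
         "element size out of range for this sketch");
  assert(Mask.size() * ElementSizeInBytes <= 32 &&
         "the byte mask must fit in 32 bits");
  const uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;

  uint32_t UsedBytesMask = 0;
  // Walk the elements from last to first so that earlier elements end up in
  // the lower bits. The first shift is a no-op because the mask is still 0.
  for (auto It = Mask.rbegin(); It != Mask.rend(); ++It) {
    UsedBytesMask <<= ElementSizeInBytes;
    if (*It)
      UsedBytesMask |= ElementMask;
  }
  return UsedBytesMask;
}
```

With 4-byte elements and the mask `{1, 0, 0, 1, 1, 0, 0, 1}` from the masked
store example, this returns `0xF00FF00F`. In the patch, any value other than
`UINT32_MAX` leads the instruction printer to emit the pragma before the load.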

Some generic changes that are needed:
* `LegalizeVectorTypes.cpp` - Ensure flags are preserved when splitting
masked loads.
* `SelectionDAGBuilder.cpp` - Preserve `MD_invariant_load` on masked
load SDNode creation.

Added: 
    llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
    llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
    llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
    llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
    llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
    llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
    llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
    llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
    llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
    llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
    llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
    llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
    llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
    llvm/lib/Target/VE/VETargetTransformInfo.h
    llvm/lib/Target/X86/X86TargetTransformInfo.cpp
    llvm/lib/Target/X86/X86TargetTransformInfo.h
    llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
    llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
    llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
    llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
    llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index bd4c901e9bc82..22cff2035eb0b 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -842,12 +842,20 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
+  /// Some targets only support masked load/store with a constant mask.
+  enum MaskKind {
+    VariableOrConstantMask,
+    ConstantMask,
+  };
+
   /// Return true if the target supports masked store.
-  LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     MaskKind MaskKind = VariableOrConstantMask) const;
   /// Return true if the target supports masked load.
-  LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    MaskKind MaskKind = VariableOrConstantMask) const;
 
   /// Return true if the target supports nontemporal store.
   LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 580b219ddbe53..4954c0d90a1e1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,12 +309,14 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace,
+                                  TTI::MaskKind MaskKind) const {
     return false;
   }
 
   virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                 unsigned AddressSpace) const {
+                                 unsigned AddressSpace,
+                                 TTI::MaskKind MaskKind) const {
     return false;
   }
 

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 46f90b4cec7c9..9768202a9ba26 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,13 +468,17 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace,
+                                             TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+                                     MaskKind);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                            unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+                                            unsigned AddressSpace,
+                                            TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+                                    MaskKind);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 24a18e181ba80..4274e951446b8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2465,6 +2465,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
   SDValue PassThru = MLD->getPassThru();
   Align Alignment = MLD->getBaseAlign();
   ISD::LoadExtType ExtType = MLD->getExtensionType();
+  MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
 
   // Split Mask operand
   SDValue MaskLo, MaskHi;
@@ -2490,9 +2491,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
     std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-      MLD->getPointerInfo(), MachineMemOperand::MOLoad,
-      LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
-      MLD->getRanges());
+      MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
+      Alignment, MLD->getAAInfo(), MLD->getRanges());
 
   Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
                          MMO, MLD->getAddressingMode(), ExtType,
@@ -2515,8 +2515,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
           LoMemVT.getStoreSize().getFixedValue());
 
     MMO = DAG.getMachineFunction().getMachineMemOperand(
-        MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
-        Alignment, MLD->getAAInfo(), MLD->getRanges());
+        MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
+        MLD->getAAInfo(), MLD->getRanges());
 
     Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
                            HiMemVT, MMO, MLD->getAddressingMode(), ExtType,

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 985a54ca83256..88b35582a9f7d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5063,6 +5063,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
   auto MMOFlags = MachineMemOperand::MOLoad;
   if (I.hasMetadata(LLVMContext::MD_nontemporal))
     MMOFlags |= MachineMemOperand::MONonTemporal;
+  if (I.hasMetadata(LLVMContext::MD_invariant_load))
+    MMOFlags |= MachineMemOperand::MOInvariant;
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
       MachinePointerInfo(PtrOperand), MMOFlags,

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 6cc4987428567..52fc28a98449b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -323,12 +323,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/,
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index d12b802fe234f..fdb0ec40cb41f 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
 }
 
 bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned /*AddressSpace*/) const {
+                                   unsigned /*AddressSpace*/,
+                                   TTI::MaskKind /*MaskKind*/) const {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
 

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 919a6fc9fd0b0..30f2151b41239 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,12 +186,16 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I) const override;
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                         unsigned AddressSpace) const override;
-
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
-    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+  bool
+  isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+
+  bool
+  isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override {
+    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
   }
 
   bool forceScalarizeMaskedGather(VectorType *VTy,

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 8f3f0cc8abb01..3f84cbb6555ed 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -343,14 +343,16 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/,
+                                        TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
 }
 
 bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
-                                       unsigned /*AddressSpace*/) const {
+                                       unsigned /*AddressSpace*/,
+                                       TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);

diff  --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e95b5a10b76a7..67388984bb3e3 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -165,9 +165,10 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
+                          unsigned AddressSpace,
+                          TTI::MaskKind MaskKind) const override;
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
   bool isLegalMaskedGather(Type *Ty, Align Alignment) const override;
   bool isLegalMaskedScatter(Type *Ty, Align Alignment) const override;
   bool forceScalarizeMaskedGather(VectorType *VTy,

diff  --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 77913f27838e2..5ff5fa36ac467 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,25 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
+                                                raw_ostream &O) {
+  auto &Op = MI->getOperand(OpNum);
+  assert(Op.isImm() && "Invalid operand");
+  uint32_t Imm = (uint32_t)Op.getImm();
+  if (Imm != UINT32_MAX) {
+    O << ".pragma \"used_bytes_mask " << format_hex(Imm, 1) << "\";\n\t";
+  }
+}
+
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O) {
+  const MCOperand &Op = MI->getOperand(OpNum);
+  if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+    O << "_";
+  else
+    printOperand(MI, OpNum, O);
+}
+
 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
                                       raw_ostream &O) {
   int64_t Imm = MI->getOperand(OpNum).getImm();

diff  --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..3d172441adfcc 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
+  void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index a3496090def3c..c8b53571c1e59 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
   const MachineOperand *ParamSymbol = Mov.uses().begin();
   assert(ParamSymbol->isSymbol());
 
-  constexpr unsigned LDInstBasePtrOpIdx = 5;
+  constexpr unsigned LDInstBasePtrOpIdx = 6;
   constexpr unsigned LDInstAddrSpaceOpIdx = 2;
   for (auto *LI : LoadInsts) {
     (LI->uses().begin() + LDInstBasePtrOpIdx)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 996d653940118..0e1125ab8d8b3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
   switch (N->getOpcode()) {
   case ISD::LOAD:
   case ISD::ATOMIC_LOAD:
+  case NVPTXISD::MLoad:
     if (tryLoad(N))
       return;
     break;
@@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
           ? NVPTX::PTXLdStInstCode::Signed
           : NVPTX::PTXLdStInstCode::Untyped;
 
+  uint32_t UsedBytesMask;
+  switch (N->getOpcode()) {
+  case ISD::LOAD:
+  case ISD::ATOMIC_LOAD:
+    UsedBytesMask = UINT32_MAX;
+    break;
+  case NVPTXISD::MLoad:
+    UsedBytesMask = N->getConstantOperandVal(3);
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+
   assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
          FromTypeWidth <= 128 && "Invalid width for load");
 
@@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                    getI32Imm(CodeAddrSpace, DL),
                    getI32Imm(FromType, DL),
                    getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
                    Base,
                    Offset,
                    Chain};
@@ -1196,14 +1211,14 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   //          type is integer
   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
   // Read at least 8 bits (predicates are stored as 8-bit values)
-  // The last operand holds the original LoadSDNode::getExtensionType() value
-  const unsigned ExtensionType =
-      N->getConstantOperandVal(N->getNumOperands() - 1);
+  // Get the original LoadSDNode::getExtensionType() value
+  const unsigned ExtensionType = N->getConstantOperandVal(4);
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
                                 : NVPTX::PTXLdStInstCode::Untyped;
 
   const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
+  const uint32_t UsedBytesMask = N->getConstantOperandVal(3);
 
   assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
 
@@ -1213,6 +1228,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
                    getI32Imm(CodeAddrSpace, DL),
                    getI32Imm(FromType, DL),
                    getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
                    Base,
                    Offset,
                    Chain};
@@ -1250,10 +1266,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   SDLoc DL(LD);
 
   unsigned ExtensionType;
+  uint32_t UsedBytesMask;
   if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
     ExtensionType = Load->getExtensionType();
+    UsedBytesMask = UINT32_MAX;
   } else {
-    ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+    ExtensionType = LD->getConstantOperandVal(4);
+    UsedBytesMask = LD->getConstantOperandVal(3);
   }
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
@@ -1265,8 +1284,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
            ExtensionType != ISD::NON_EXTLOAD));
 
   const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
-  SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
-                   Offset, LD->getChain()};
+  SDValue Ops[] = {getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
+                   Base,
+                   Offset,
+                   LD->getChain()};
 
   const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
   std::optional<unsigned> Opcode;
@@ -1277,6 +1300,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
                              NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
     break;
+  case NVPTXISD::MLoad:
+    Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
+                             NVPTX::LD_GLOBAL_NC_i64);
+    break;
   case NVPTXISD::LoadV2:
     Opcode =
         pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 1773958520f04..f3a3bc785d997 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -770,7 +770,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                      Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+                         Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -3133,6 +3134,86 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
   return Or;
 }
 
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Val = N->getOperand(1);
+  SDValue BasePtr = N->getOperand(2);
+  SDValue Offset = N->getOperand(3);
+  SDValue Mask = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ValVT = Val.getValueType();
+  MemSDNode *MemSD = cast<MemSDNode>(N);
+  assert(ValVT.isVector() && "Masked vector store must have vector type");
+  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+         "Unexpected alignment for masked store");
+
+  unsigned Opcode = 0;
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected masked vector store type");
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    Opcode = NVPTXISD::StoreV8;
+    break;
+  }
+  }
+
+  SmallVector<SDValue, 8> Ops;
+
+  // Construct the new SDNode. First operand is the chain.
+  Ops.push_back(Chain);
+
+  // The next N operands are the values to store. Encode the mask into the
+  // values using the sentinel register 0 to represent a masked-off element.
+  assert(Mask.getValueType().isVector() &&
+         Mask.getValueType().getVectorElementType() == MVT::i1 &&
+         "Mask must be a vector of i1");
+  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+         "Mask expected to be a BUILD_VECTOR");
+  assert(Mask.getValueType().getVectorNumElements() ==
+             ValVT.getVectorNumElements() &&
+         "Mask size must be the same as the vector size");
+  for (auto [I, Op] : enumerate(Mask->ops())) {
+    // Mask elements must be constants.
+    if (Op.getNode()->getAsZExtVal() == 0) {
+      // Append a sentinel register 0 to the Ops vector to represent a masked
+      // off element, this will be handled in tablegen
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Extract the element from the vector to store
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+                      Val, DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Next, the pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it will
+  // be ignored in lowering, but to mirror the handling of the other vector
+  // store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -3169,8 +3250,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::MLOAD:
+    return LowerMLOAD(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
@@ -3363,10 +3452,56 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+static std::pair<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it creates
+  // two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used bytes
+  // of the entire vector load
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (SDValue Op : reverse(Mask->ops())) {
+    // We technically only want to do this shift for every
+    // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
+    // so this shift is a no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    // Mask elements must be constants.
+    if (Op->getAsZExtVal() != 0)
+      UsedBytesMask |= ElementMask;
+  }
+
+  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+         "Unexpected masked load with elements masked all on or all off");
+
+  // Create a new load sd node to be handled normally by ReplaceLoadVector.
+  MemSDNode *NewLD = cast<MemSDNode>(
+      DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+  return {NewLD, UsedBytesMask};
+}
+
 /// replaceLoadVector - Convert vector loads into multi-output scalar loads.
 static std::optional<std::pair<SDValue, SDValue>>
 replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
+  MemSDNode *LD = cast<MemSDNode>(N);
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
 
@@ -3393,6 +3528,11 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
     return std::nullopt;
   }
 
+  // If we have a masked load, convert it to a normal load now
+  std::optional<uint32_t> UsedBytesMask = std::nullopt;
+  if (LD->getOpcode() == ISD::MLOAD)
+    std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
   // loaded type to i16 and propagate the "real" type as the memory type.
@@ -3421,9 +3561,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
   // Copy regular operands
   SmallVector<SDValue, 8> OtherOps(LD->ops());
 
+  OtherOps.push_back(
+      DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
   // The select routine does not have access to the LoadSDNode instance, so
   // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  OtherOps.push_back(
+      DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
 
   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
                                           LD->getMemOperand());
@@ -3511,6 +3655,42 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   llvm_unreachable("Unexpected custom lowering for load");
 }
 
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
+  // masked loads of these types and have to handle them here.
+  // v2f32 also needs to be handled here if the subtarget has f32x2
+  // instructions, making it legal.
+  //
+  // Note: misaligned masked loads should never reach this point
+  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+  // will validate alignment. Therefore, we do not need to special case handle
+  // them here.
+  EVT VT = Op.getValueType();
+  if (NVPTX::isPackedVectorTy(VT)) {
+    auto Result =
+        convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+    MemSDNode *LD = std::get<0>(Result);
+    uint32_t UsedBytesMask = std::get<1>(Result);
+
+    SDLoc DL(LD);
+
+    // Copy regular operands
+    SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+    OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+    // We currently are not lowering extending loads, but pass the extension
+    // type anyway as later handling expects it.
+    OtherOps.push_back(
+        DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    SDValue NewLD =
+        DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+                                LD->getMemoryVT(), LD->getMemOperand());
+    return NewLD;
+  }
+  return SDValue();
+}
+
 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5419,6 +5599,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
     // here.
     Opcode = NVPTXISD::LoadV2;
+    // append a "full" used bytes mask operand right before the extension type
+    // operand, signifying that all bytes are used.
+    Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
     Operands.push_back(DCI.DAG.getIntPtrConstant(
         cast<LoadSDNode>(LD)->getExtensionType(), DL));
     break;
@@ -5427,9 +5610,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     Opcode = NVPTXISD::LoadV4;
     break;
   case NVPTXISD::LoadV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the load
-    // size here. This is already a 256-bit load.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // load size here. This is already a 256-bit load.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
       return SDValue();
     OldNumOutputs = 4;
     Opcode = NVPTXISD::LoadV8;
@@ -5504,9 +5687,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
     Opcode = NVPTXISD::StoreV4;
     break;
   case NVPTXISD::StoreV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the store
-    // size here. This is already a 256-bit store.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // store size here. This is already a 256-bit store.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
       return SDValue();
     Opcode = NVPTXISD::StoreV8;
     break;
@@ -6657,6 +6840,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
+  case ISD::MLOAD:
     replaceLoadVector(N, DAG, Results, STI);
     return;
   case ISD::INTRINSIC_W_CHAIN:
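
Reader's note (not part of the commit): the mask conversion done by convertMLOADToLoadWithUsedBytesMask above can be reproduced outside of LLVM. The following is a minimal sketch, assuming the element mask is given as a plain std::vector<bool>; the function name usedBytesMask and its signature are hypothetical and exist only for illustration. It expands each mask element into ElementSizeInBytes bits, with element 0 in the lowest bits.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not an LLVM API): expands a per-element
// mask into a per-byte mask, where bit i covers byte i of the loaded vector.
uint32_t usedBytesMask(const std::vector<bool> &ElementMask,
                       uint32_t ElementSizeInBytes) {
  // Assumption for this sketch: elements are 1 to 8 bytes wide and the whole
  // vector is at most 32 bytes, so the result fits in 32 bits.
  assert(ElementSizeInBytes >= 1 && ElementSizeInBytes <= 8 &&
         "unexpected element size");
  assert(ElementMask.size() * ElementSizeInBytes <= 32 &&
         "byte mask must fit in 32 bits");

  // Bits covering a single element, e.g. 0xf for a 4-byte element.
  const uint32_t OneElement = (1u << ElementSizeInBytes) - 1u;

  uint32_t Mask = 0;
  // Walk from the last element to the first so that element 0 ends up in
  // the least significant bits.
  for (auto It = ElementMask.rbegin(); It != ElementMask.rend(); ++It) {
    Mask <<= ElementSizeInBytes; // No-op on the first iteration (Mask == 0).
    if (*It)
      Mask |= OneElement;
  }
  return Mask;
}
```

For the <8 x i32> mask <1, 0, 1, 0, 0, 0, 0, 1> used in masked-load-vectors.ll, this yields 0xf0000f0f. Its two <4 x i32> halves yield 0xf0f and 0xf000, which are the values in the SM90 and SM100 check lines of that test.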

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index d71a86fd463f6..dd8e49de7aa6a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -235,6 +235,7 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 04e2dd435cdf0..feefaf9a21e5b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1588,6 +1588,14 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def UsedBytesMask : Operand<i32> {
+  let PrintMethod = "printUsedBytesMaskPragma";
+}
+
+def RegOrSink : Operand<Any> {
+  let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
 def AtomicCode : Operand<i32> {
   let PrintMethod = "printAtomicCode";
 }
@@ -1832,8 +1840,10 @@ def Callseq_End :
 class LD<NVPTXRegClass regclass>
   : NVPTXInst<
     (outs regclass:$dst),
-    (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
-         i32imm:$fromWidth, ADDR:$addr),
+    (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$fromWidth "
     "\t$dst, [$addr];">;
 
@@ -1865,21 +1875,27 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2),
     (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
-         AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2}}, [$addr];">;
   def _v4 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
     (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
-         AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];">;
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
             regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
-      (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
-           i32imm:$fromWidth, ADDR:$addr),
+      (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+           AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+           ADDR:$addr),
+      "${usedBytes}"
       "ld${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
       "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
       "[$addr];">;
@@ -1900,7 +1916,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
     "\t[$addr], {{$src1, $src2}};">;
   def _v4 : NVPTXInst<
     (outs),
-    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+    (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
          AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1908,8 +1924,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins O:$src1, O:$src2, O:$src3, O:$src4,
-           O:$src5, O:$src6, O:$src7, O:$src8,
+      (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+           RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
            AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
            ADDR:$addr),
       "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8501d4d7bb86f..d18c7e20df038 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2552,7 +2552,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
 // during the lifetime of the kernel.
 
 class LDG_G<NVPTXRegClass regclass>
-  : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+  : NVPTXInst<(outs regclass:$result),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth,
+                   UsedBytesMask:$usedBytes, ADDR:$src),
+               "${usedBytes}"
                "ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
 
 def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2564,19 +2567,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
 // Elementized vector ldg
 class VLDG_G_ELE_V2<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
-            (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
             "ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
 
 
 class VLDG_G_ELE_V4<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), 
-            (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
             "ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
 
 class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
                   regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
-             (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
              "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
 
 // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.

diff  --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index 320c0fb6950a7..4bbf49f93f43b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -1808,8 +1808,8 @@ bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
       // For CUDA, we preserve the param loads coming from function arguments
       return false;
 
-    assert(TexHandleDef.getOperand(6).isSymbol() && "Load is not a symbol!");
-    StringRef Sym = TexHandleDef.getOperand(6).getSymbolName();
+    assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
+    StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
     InstrsToRemove.insert(&TexHandleDef);
     Op.ChangeToES(Sym.data());
     MFI->getImageHandleSymbolIndex(Sym);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
index e8ea1ad6c404d..710d063e75725 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp
@@ -30,6 +30,7 @@ const char *NVPTXSelectionDAGInfo::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::LoadV2)
     MAKE_CASE(NVPTXISD::LoadV4)
     MAKE_CASE(NVPTXISD::LoadV8)
+    MAKE_CASE(NVPTXISD::MLoad)
     MAKE_CASE(NVPTXISD::LDUV2)
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)

diff  --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
index 07c130baeaa4f..9dd0a1eaa5856 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h
@@ -36,6 +36,7 @@ enum NodeType : unsigned {
   LoadV2,
   LoadV4,
   LoadV8,
+  MLoad,
   LDUV2, // LDU.v2
   LDUV4, // LDU.v4
   StoreV2,

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
index a4aff44ac04f6..f1774a7c5572e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
@@ -27,13 +27,14 @@
 
 using namespace llvm;
 
-static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
+static bool isInvariantLoad(const Instruction *I, const Value *Ptr,
+                            const bool IsKernelFn) {
   // Don't bother with non-global loads
-  if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
+  if (Ptr->getType()->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
     return false;
 
   // If the load is already marked as invariant, we don't need to do anything
-  if (LI->getMetadata(LLVMContext::MD_invariant_load))
+  if (I->getMetadata(LLVMContext::MD_invariant_load))
     return false;
 
   // We use getUnderlyingObjects() here instead of getUnderlyingObject()
@@ -41,7 +42,7 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
   // not. We need to look through phi nodes to handle pointer induction
   // variables.
   SmallVector<const Value *, 8> Objs;
-  getUnderlyingObjects(LI->getPointerOperand(), Objs);
+  getUnderlyingObjects(Ptr, Objs);
 
   return all_of(Objs, [&](const Value *V) {
     if (const auto *A = dyn_cast<const Argument>(V))
@@ -53,9 +54,9 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
   });
 }
 
-static void markLoadsAsInvariant(LoadInst *LI) {
-  LI->setMetadata(LLVMContext::MD_invariant_load,
-                  MDNode::get(LI->getContext(), {}));
+static void markLoadsAsInvariant(Instruction *I) {
+  I->setMetadata(LLVMContext::MD_invariant_load,
+                 MDNode::get(I->getContext(), {}));
 }
 
 static bool tagInvariantLoads(Function &F) {
@@ -63,12 +64,17 @@ static bool tagInvariantLoads(Function &F) {
 
   bool Changed = false;
   for (auto &I : instructions(F)) {
-    if (auto *LI = dyn_cast<LoadInst>(&I)) {
-      if (isInvariantLoad(LI, IsKernelFn)) {
+    if (auto *LI = dyn_cast<LoadInst>(&I))
+      if (isInvariantLoad(LI, LI->getPointerOperand(), IsKernelFn)) {
         markLoadsAsInvariant(LI);
         Changed = true;
       }
-    }
+    if (auto *II = dyn_cast<IntrinsicInst>(&I))
+      if (II->getIntrinsicID() == Intrinsic::masked_load &&
+          isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
+        markLoadsAsInvariant(II);
+        Changed = true;
+      }
   }
   return Changed;
 }

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 64593e6439184..5d5553c573b0f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -592,6 +592,45 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                      unsigned AddrSpace,
+                                      TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
+    return false;
+
+  //  We currently only support this feature for 256-bit vectors, so the
+  //  alignment must be at least 32
+  if (Alignment < 32)
+    return false;
+
+  if (!ST->has256BitVectorLoadStore(AddrSpace))
+    return false;
+
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+
+  auto *ElemTy = VTy->getScalarType();
+  return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+         (ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
+}
+
+bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+                                     unsigned /*AddrSpace*/,
+                                     TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
+    return false;
+
+  if (Alignment < DL.getTypeStoreSize(DataTy))
+    return false;
+
+  // We do not support sub-byte element type masked loads.
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+  return VTy->getElementType()->getScalarSizeInBits() >= 8;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
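
Reader's note (not part of the commit): the two legality checks added above can be modelled as plain predicates. This is a minimal sketch under stated assumptions: MaskKind, VectorDesc and both function names are hypothetical stand-ins for the LLVM types, the subtarget query is reduced to a bool parameter, and the type store size is approximated as the total bit width rounded up to whole bytes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for TTI::MaskKind.
enum class MaskKind { VariableOrConstantMask, ConstantMask };

// Hypothetical description of a fixed-width vector type.
struct VectorDesc {
  uint32_t ElementBits; // Size of one element in bits.
  uint32_t NumElements; // Number of elements.
};

// Mirrors the masked store check: constant mask, 32-byte alignment,
// 256-bit load/store support, and an 8 x 32-bit or 4 x 64-bit vector.
bool isLegalMaskedStoreSketch(VectorDesc V, uint64_t AlignBytes,
                              bool Has256BitLoadStore, MaskKind Kind) {
  if (Kind != MaskKind::ConstantMask)
    return false;
  if (AlignBytes < 32)
    return false;
  if (!Has256BitLoadStore)
    return false;
  return (V.ElementBits == 32 && V.NumElements == 8) ||
         (V.ElementBits == 64 && V.NumElements == 4);
}

// Mirrors the masked load check: constant mask, alignment covering the whole
// vector, and elements of at least one byte.
bool isLegalMaskedLoadSketch(VectorDesc V, uint64_t AlignBytes, MaskKind Kind) {
  if (Kind != MaskKind::ConstantMask)
    return false;
  // Assumption: store size is the total bit width rounded up to bytes.
  const uint64_t StoreSizeBytes =
      (uint64_t(V.ElementBits) * V.NumElements + 7) / 8;
  if (AlignBytes < StoreSizeBytes)
    return false;
  return V.ElementBits >= 8;
}
```

Under this model a 32-byte-aligned <8 x i32> access with a constant mask is legal for both loads and stores, while a <16 x i16> access is legal only as a load. A variable mask is rejected in every case.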

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 78eb751cf3c2e..d7f4e1da4073b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,12 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+                          TTI::MaskKind MaskKind) const override;
+
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
+                         TTI::MaskKind MaskKind) const override;
+
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,

diff  --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 39c1173e2986c..484c4791390ac 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -285,11 +285,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/,
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 

diff  --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..eed3832c9f1fb 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -134,12 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
   }
 
   // Load & Store {
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+                    TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+  bool isLegalMaskedStore(
+      Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+      TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 6ae55f3623096..a9dc96b53d530 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6322,7 +6322,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
 }
 
 bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned AddressSpace) const {
+                                   unsigned AddressSpace,
+                                   TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
@@ -6335,7 +6336,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace) const {
+                                    unsigned AddressSpace,
+                                    TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.

diff  --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index f8866472bd9af..d6dea9427990b 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -267,10 +267,14 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2) const override;
   bool canMacroFuseCmp() const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+  bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,

diff  --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 146e7d1047dd0..b7b08ae61ec52 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedLoad(
               CI->getType(), CI->getParamAlign(0).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(0)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(1))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
@@ -1132,7 +1135,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getArgOperand(0)->getType(),
               CI->getParamAlign(1).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(2))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;

diff  --git a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
index e3b072549bc04..3158916a3195c 100644
--- a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
+++ b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
@@ -40,9 +40,9 @@ registers:
   - { id: 7, class: b32 }
 body: |
   bb.0.entry:
-    %0 = LD_i32 0, 0, 4, 2, 32, &test_param_0, 0
+    %0 = LD_i32 0, 0, 4, 2, 32, -1, &test_param_0, 0
     %1 = CVT_f64_f32 %0, 0
-    %2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0
+    %2 = LD_i32 0, 0, 4, 0, 32, -1, &test_param_1, 0
   ; CHECK: %3:b64 = FADD_rnf64ri %1, double 3.250000e+00
     %3 = FADD_rnf64ri %1, double 3.250000e+00
     %4 = CVT_f32_f64 %3, 5
@@ -66,9 +66,9 @@ registers:
   - { id: 7, class: b32 }
 body: |
   bb.0.entry:
-    %0 = LD_i32 0, 0, 4, 2, 32, &test2_param_0, 0
+    %0 = LD_i32 0, 0, 4, 2, 32, -1, &test2_param_0, 0
     %1 = CVT_f64_f32 %0, 0
-    %2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0
+    %2 = LD_i32 0, 0, 4, 0, 32, -1, &test2_param_1, 0
   ; CHECK: %3:b64 = FADD_rnf64ri %1, double 0x7FF8000000000000
     %3 = FADD_rnf64ri %1, double 0x7FF8000000000000
     %4 = CVT_f32_f64 %3, 5

diff  --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
index 3fac29f74125b..d219493d2b31b 100644
--- a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -346,19 +346,15 @@ define i32 @ld_global_v8i32(ptr addrspace(1) %ptr) {
 ; SM100-LABEL: ld_global_v8i32(
 ; SM100:       {
 ; SM100-NEXT:    .reg .b32 %r<16>;
-; SM100-NEXT:    .reg .b64 %rd<6>;
+; SM100-NEXT:    .reg .b64 %rd<2>;
 ; SM100-EMPTY:
 ; SM100-NEXT:  // %bb.0:
 ; SM100-NEXT:    ld.param.b64 %rd1, [ld_global_v8i32_param_0];
-; SM100-NEXT:    ld.global.nc.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
-; SM100-NEXT:    mov.b64 {%r1, %r2}, %rd5;
-; SM100-NEXT:    mov.b64 {%r3, %r4}, %rd4;
-; SM100-NEXT:    mov.b64 {%r5, %r6}, %rd3;
-; SM100-NEXT:    mov.b64 {%r7, %r8}, %rd2;
-; SM100-NEXT:    add.s32 %r9, %r7, %r8;
-; SM100-NEXT:    add.s32 %r10, %r5, %r6;
-; SM100-NEXT:    add.s32 %r11, %r3, %r4;
-; SM100-NEXT:    add.s32 %r12, %r1, %r2;
+; SM100-NEXT:    ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    add.s32 %r9, %r1, %r2;
+; SM100-NEXT:    add.s32 %r10, %r3, %r4;
+; SM100-NEXT:    add.s32 %r11, %r5, %r6;
+; SM100-NEXT:    add.s32 %r12, %r7, %r8;
 ; SM100-NEXT:    add.s32 %r13, %r9, %r10;
 ; SM100-NEXT:    add.s32 %r14, %r11, %r12;
 ; SM100-NEXT:    add.s32 %r15, %r13, %r14;

diff  --git a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
index 0b2d85600a2ef..4be91dfc60c6a 100644
--- a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
+++ b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
@@ -26,10 +26,10 @@ body:             |
   ; CHECK: bb.0.entry:
   ; CHECK-NEXT:   successors: %bb.2(0x30000000), %bb.3(0x50000000)
   ; CHECK-NEXT: {{  $}}
-  ; CHECK-NEXT:   [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
-  ; CHECK-NEXT:   [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+  ; CHECK-NEXT:   [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+  ; CHECK-NEXT:   [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
   ; CHECK-NEXT:   [[ADD64ri:%[0-9]+]]:b64 = nuw ADD64ri killed [[LD_i64_]], 2
-  ; CHECK-NEXT:   [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, [[ADD64ri]], 0
+  ; CHECK-NEXT:   [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, -1, [[ADD64ri]], 0
   ; CHECK-NEXT:   [[SETP_i32ri:%[0-9]+]]:b1 = SETP_i32ri [[LD_i32_]], 0, 0
   ; CHECK-NEXT:   CBranch killed [[SETP_i32ri]], %bb.2
   ; CHECK-NEXT: {{  $}}
@@ -54,10 +54,10 @@ body:             |
   bb.0.entry:
     successors: %bb.2(0x30000000), %bb.1(0x50000000)
 
-    %5:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
-    %6:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+    %5:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+    %6:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
     %0:b64 = nuw ADD64ri killed %6, 2
-    %1:b32 = LD_i32 0, 0, 1, 3, 32, %0, 0
+    %1:b32 = LD_i32 0, 0, 1, 3, 32, -1, %0, 0
     %7:b1 = SETP_i32ri %5, 0, 0
     CBranch killed %7, %bb.2
     GOTO %bb.1

diff  --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
new file mode 100644
index 0000000000000..8617dea310d6c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -0,0 +1,366 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+
+; Different architectures are tested in this file for the following reasons:
+; - SM90 does not have 256-bit load/store instructions
+; - SM90 does not have masked store instructions
+; - SM90 does not support packed f32x2 instructions
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+; Masked stores are only supported for 32-bit element types,
+; while masked loads are supported for all element types.
+define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_16xi16(
+; SM90:       {
+; SM90-NEXT:    .reg .b16 %rs<7>;
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r7;
+; SM90-NEXT:    mov.b32 {%rs5, %rs6}, %r5;
+; SM90-NEXT:    ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM90-NEXT:    st.global.b16 [%rd2], %rs5;
+; SM90-NEXT:    st.global.b16 [%rd2+2], %rs6;
+; SM90-NEXT:    st.global.b16 [%rd2+8], %rs3;
+; SM90-NEXT:    st.global.b16 [%rd2+10], %rs4;
+; SM90-NEXT:    st.global.b16 [%rd2+28], %rs1;
+; SM90-NEXT:    st.global.b16 [%rd2+30], %rs2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_16xi16(
+; SM100:       {
+; SM100-NEXT:    .reg .b16 %rs<7>;
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    mov.b32 {%rs1, %rs2}, %r8;
+; SM100-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM100-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
+; SM100-NEXT:    ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM100-NEXT:    st.global.b16 [%rd2], %rs5;
+; SM100-NEXT:    st.global.b16 [%rd2+2], %rs6;
+; SM100-NEXT:    st.global.b16 [%rd2+8], %rs3;
+; SM100-NEXT:    st.global.b16 [%rd2+10], %rs4;
+; SM100-NEXT:    st.global.b16 [%rd2+28], %rs1;
+; SM100-NEXT:    st.global.b16 [%rd2+30], %rs2;
+; SM100-NEXT:    ret;
+  %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) align 32 %a, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+  tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) align 32 %b, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_8xi32_no_align_param_0];
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_8xi32_no_align_param_1];
+; CHECK-NEXT:    ld.global.b32 %r2, [%rd1+8];
+; CHECK-NEXT:    ld.global.b32 %r3, [%rd1+28];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r1;
+; CHECK-NEXT:    st.global.b32 [%rd2+8], %r2;
+; CHECK-NEXT:    st.global.b32 [%rd2+28], %r3;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 16 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 16 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+
+define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_invariant(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf000";
+; SM90-NEXT:    ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf0f";
+; SM90-NEXT:    ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_invariant(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 0xf0000f0f";
+; SM100-NEXT:    ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 0x3";
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_param_1];
+; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_invariant(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_invariant_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 0x3";
+; CHECK-NEXT:    ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_invariant_param_1];
+; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_no_align_param_0];
+; CHECK-NEXT:    ld.global.b16 %rs1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 2 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_param_1];
+; CHECK-NEXT:    st.global.b8 [%rd2], %r1;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_invariant(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT:    ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
+; CHECK-NEXT:    st.global.b8 [%rd2], %r1;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_no_align_param_0];
+; CHECK-NEXT:    ld.global.b8 %rs1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_no_align_param_1];
+; CHECK-NEXT:    ld.global.b8 %rs2, [%rd1+2];
+; CHECK-NEXT:    st.global.b8 [%rd2], %rs1;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %rs2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 2 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; In sm100+, we pack 2xf32 loads into a single b64 load while lowering
+define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<3>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf";
+; SM90-NEXT:    ld.global.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_2xf32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_2xf32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<2>;
+; SM100-NEXT:    .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 0xf";
+; SM100-NEXT:    ld.global.b64 %rd2, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd3, [global_2xf32_param_1];
+; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT:    st.global.b32 [%rd3], %r1;
+; SM100-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32_invariant(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<3>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 0xf";
+; SM90-NEXT:    ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_2xf32_invariant(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<2>;
+; SM100-NEXT:    .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 0xf";
+; SM100-NEXT:    ld.global.nc.b64 %rd2, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
+; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT:    st.global.b32 [%rd3], %r1;
+; SM100-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xf32_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xf32_no_align_param_0];
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), <2 x i1>)
+!0 = !{}
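
Editorial note on the `used_bytes_mask` values checked in masked-load-vectors.ll above: the pragma operand appears to carry one bit per byte of the access, with the bits of enabled lanes set. The sketch below reproduces the constants seen in the CHECK lines under that assumption. The helper names `used_bytes_mask` and `split_halves` are hypothetical and are not taken from the NVPTX backend; the actual computation in the commit may be organized differently.

```python
def used_bytes_mask(lane_mask, elem_bytes):
    """Return an int with one bit per byte; bit i is set when byte i
    belongs to an enabled lane. (Hypothetical helper, not LLVM code.)"""
    mask = 0
    lane_bits = (1 << elem_bytes) - 1  # e.g. 0xf for a 4-byte element
    for lane, enabled in enumerate(lane_mask):
        if enabled:
            mask |= lane_bits << (lane * elem_bytes)
    return mask


def split_halves(lane_mask):
    """Split a lane mask in two, modelling the two 128-bit loads that
    the sm_90 output uses in place of one 256-bit load."""
    half = len(lane_mask) // 2
    return lane_mask[:half], lane_mask[half:]


# Mask used by global_8xi32: lanes 0, 2 and 7 are enabled.
m8 = [True, False, True, False, False, False, False, True]
print(hex(used_bytes_mask(m8, 4)))      # 0xf0000f0f (SM100, one v8.b32 load)

low, high = split_halves(m8)
print(hex(used_bytes_mask(low, 4)))     # 0xf0f  (SM90 load at [%rd1])
print(hex(used_bytes_mask(high, 4)))    # 0xf000 (SM90 load at [%rd1+16])

# Narrower element types from the same test file.
print(hex(used_bytes_mask([True, False], 2)))               # 0x3 (2 x i16)
print(hex(used_bytes_mask([True, False, True, False], 1)))  # 0x5 (4 x i8)
```

The same function applied to the 16 x i16 mask (lanes 0, 1, 4, 5, 14, 15) also yields 0xf0000f0f, which agrees with the global_16xi16 checks.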

diff  --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..9f23acaf93bc8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<9>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT:    ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT:    ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT:    and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT:    setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT:    ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT:    and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT:    setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT:    ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT:    ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT:    not.pred %p5, %p1;
+; CHECK-NEXT:    @%p5 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %cond.store
+; CHECK-NEXT:    st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT:  $L__BB0_2: // %else
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT:    not.pred %p6, %p2;
+; CHECK-NEXT:    @%p6 bra $L__BB0_4;
+; CHECK-NEXT:  // %bb.3: // %cond.store1
+; CHECK-NEXT:    st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT:  $L__BB0_4: // %else2
+; CHECK-NEXT:    setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT:    not.pred %p7, %p3;
+; CHECK-NEXT:    @%p7 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.5: // %cond.store3
+; CHECK-NEXT:    st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT:  $L__BB0_6: // %else4
+; CHECK-NEXT:    not.pred %p8, %p4;
+; CHECK-NEXT:    @%p8 bra $L__BB0_8;
+; CHECK-NEXT:  // %bb.7: // %cond.store5
+; CHECK-NEXT:    st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT:  $L__BB0_8: // %else6
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
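
Editorial note on masked-store-variable-mask.ll above: with a non-constant mask the store is scalarized, so each lane becomes a branch around an individual store (the `cond.store` / `else` blocks in the CHECK lines). The sketch below models only that observable behaviour. The function name and the dict-based memory model are assumptions made for illustration; this is not how ScalarizeMaskedMemIntrin is implemented.

```python
def masked_store_scalarized(memory, base, values, mask, elem_bytes=8):
    """Model of a scalarized masked store: one conditional store per lane.

    `memory` is a dict mapping byte address -> stored element, which is
    an illustrative simplification rather than a real memory model.
    """
    for lane, (value, enabled) in enumerate(zip(values, mask)):
        if enabled:  # corresponds to falling through into a cond.store block
            memory[base + lane * elem_bytes] = value
        # a disabled lane branches straight to the following else block
    return memory


# Four i64 lanes, as in global_variable_mask; the mask is only known at run time.
mem = masked_store_scalarized({}, 0, [10, 20, 30, 40], [True, False, True, False])
print(mem)  # {0: 10, 16: 30}
```

Lanes with a false mask bit leave memory untouched, matching the byte offsets 0, 8, 16 and 24 guarded by separate predicates in the expected PTX.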

diff  --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..feb7b7e0a0b39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,318 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll,
+; and contains testing for lowering 256-bit masked vector stores
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will get legalized before the backend via scalarization;
+;   this file tests that path even though the LSV won't generate them.
+
+; 256-bit vector loads/stores are only legal for Blackwell+, so on sm_90 the vectors will be split
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr %a
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr %a
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x float>, ptr %a
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x double>, ptr %a
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x float>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x double>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT:    st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT:    st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK:       {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; This is an example pattern for the LSV's output of these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT:    st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT:    st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT:    ret;
+  %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+  %load05 = extractelement <8 x i32> %1, i32 0
+  %load16 = extractelement <8 x i32> %1, i32 1
+  %load38 = extractelement <8 x i32> %1, i32 3
+  %load49 = extractelement <8 x i32> %1, i32 4
+  %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+  %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+  %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+  %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+  %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+  %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+  %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+  %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) align 32 %out, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), <4 x i1>)

diff  --git a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
index dfc84177fb0e6..a84b7fcd33836 100644
--- a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
+++ b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
@@ -77,7 +77,7 @@ constants:       []
 machineFunctionInfo: {}
 body:             |
   bb.0:
-    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, &retval0, 0 :: (load (s128), addrspace 101)
+    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s128), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %4:b32 = ProxyRegB32 killed %0
     %5:b32 = ProxyRegB32 killed %1
@@ -86,7 +86,7 @@ body:             |
     ; CHECK: STV_i32_v4 killed %0, killed %1, killed %2, killed %3
     STV_i32_v4 killed %4, killed %5, killed %6, killed %7, 0, 0, 101, 32, &func_retval0, 0 :: (store (s128), addrspace 101)
 
-    %8:b32 = LD_i32 0, 0, 101, 3, 32, &retval0, 0 :: (load (s32), addrspace 101)
+    %8:b32 = LD_i32 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s32), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %9:b32 = ProxyRegB32 killed %8
     %10:b32 = ProxyRegB32 killed %9
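
Note on the vectorizerOutput test above: the constant mask <1,1,0,1,1,0,0,0> appears in the SM100 checks as "_" sink symbols in the v8 store operand list, and in the SM90 checks as four scalar stores at byte offsets 0, 4, 12 and 16. The Python below is only an illustrative sketch of that mapping, not the backend's implementation; the helper names sink_operands and scalar_store_offsets are hypothetical and do not exist in LLVM.

```python
def sink_operands(mask, regs):
    """Render a PTX-style store operand list, using '_' for masked-off lanes.

    Hypothetical helper for illustration only; not an LLVM API.
    """
    if len(mask) != len(regs):
        raise ValueError("mask and register list must have the same length")
    return "{" + ", ".join(r if m else "_" for m, r in zip(mask, regs)) + "}"


def scalar_store_offsets(mask, elem_bytes):
    """Byte offsets of the enabled lanes, as a scalarized lowering would store them.

    Hypothetical helper for illustration only; not an LLVM API.
    """
    return [i * elem_bytes for i, m in enumerate(mask) if m]


# Mask and registers taken from the vectorizerOutput test (<8 x i32>, 4-byte lanes).
mask = [True, True, False, True, True, False, False, False]
regs = [f"%r{i}" for i in range(1, 9)]

print(sink_operands(mask, regs))      # {%r1, %r2, _, %r4, %r5, _, _, _}
print(scalar_store_offsets(mask, 4))  # [0, 4, 12, 16]
```

The first line printed matches the operand list in the SM100 st.global.v8.b32 check, and the second matches the offsets used by the four st.global.b32 checks for SM90.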


        

