[llvm] [NVPTX] Lower LLVM masked vector loads and stores to PTX (PR #159387)
    Drew Kersnar via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Wed Oct 22 10:04:55 PDT 2025
    
    
  
https://github.com/dakersnar updated https://github.com/llvm/llvm-project/pull/159387
>From 069c336c4b510e3b59a98a4b3e3386cd244ea77a Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:30:38 +0000
Subject: [PATCH 1/9] [NVPTX] Lower LLVM masked vector stores to PTX using the
 new sink symbol syntax
---
 .../llvm/Analysis/TargetTransformInfo.h       |   8 +-
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   2 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   4 +-
 .../AArch64/AArch64TargetTransformInfo.h      |   2 +-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.cpp    |   2 +-
 .../Hexagon/HexagonTargetTransformInfo.h      |   2 +-
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  10 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |   2 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  89 ++++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  10 +-
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp |  26 ++
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |   3 +
 .../Target/RISCV/RISCVTargetTransformInfo.h   |   2 +-
 llvm/lib/Target/VE/VETargetTransformInfo.h    |   2 +-
 .../lib/Target/X86/X86TargetTransformInfo.cpp |   2 +-
 llvm/lib/Target/X86/X86TargetTransformInfo.h  |   2 +-
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |   3 +-
 .../NVPTX/masked-store-variable-mask.ll       |  56 +++
 .../CodeGen/NVPTX/masked-store-vectors-256.ll | 319 ++++++++++++++++++
 20 files changed, 530 insertions(+), 18 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5d3b233ed6b6a..f0d30913e1a52 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -813,9 +813,13 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store.
+  /// Return true if the target supports masked store. IsMaskConstant is true
+  /// only when the mask is known to be constant; false means it may be either
+  /// variable or constant. This lets targets that only support constant-mask
+  /// masked stores reject variable masks.
   LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+                                   unsigned AddressSpace,
+                                   bool IsMaskConstant = false) const;
   /// Return true if the target supports masked load.
   LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                                   unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 4cd607c0d0c8d..06df6ca58b122 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace, bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index bf62623099a97..9310350b48585 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,8 +468,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace, bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                          unsigned AddressSpace) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
+                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 77913f27838e2..c8717876157c5 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  const MCOperand &Op = MI->getOperand(OpNum);
+  if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+    O << "_";
+  else
+    printOperand(MI, OpNum, O);
+}
+
 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
                                       raw_ostream &O) {
   int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
   void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2f1a7ad2d401f..e52854c3b5627 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -769,7 +769,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -3181,6 +3181,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
   return Or;
 }
 
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+  SDNode *N = Op.getNode();
+
+  SDValue Chain = N->getOperand(0);
+  SDValue Val = N->getOperand(1);
+  SDValue BasePtr = N->getOperand(2);
+  SDValue Offset = N->getOperand(3);
+  SDValue Mask = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ValVT = Val.getValueType();
+  MemSDNode *MemSD = cast<MemSDNode>(N);
+  assert(ValVT.isVector() && "Masked vector store must have vector type");
+  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+         "Unexpected alignment for masked store");
+
+  unsigned Opcode = 0;
+  switch (ValVT.getSimpleVT().SimpleTy) {
+  default:
+    llvm_unreachable("Unexpected masked vector store type");
+  case MVT::v4i64:
+  case MVT::v4f64: {
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  }
+  case MVT::v8i32:
+  case MVT::v8f32: {
+    Opcode = NVPTXISD::StoreV8;
+    break;
+  }
+  }
+
+  SmallVector<SDValue, 8> Ops;
+
+  // Construct the new SDNode. First operand is the chain.
+  Ops.push_back(Chain);
+
+  // The next N operands are the values to store. Encode the mask into the
+  // values using the sentinel register 0 to represent a masked-off element.
+  assert(Mask.getValueType().isVector() &&
+         Mask.getValueType().getVectorElementType() == MVT::i1 &&
+         "Mask must be a vector of i1");
+  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+         "Mask expected to be a BUILD_VECTOR");
+  assert(Mask.getValueType().getVectorNumElements() ==
+             ValVT.getVectorNumElements() &&
+         "Mask size must be the same as the vector size");
+  for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    if (Mask->getConstantOperandVal(I) == 0) {
+      // Append a sentinel register 0 to the Ops vector to represent a masked
+      // off element; the RegOrSink operand prints it as the sink symbol "_".
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Extract the element from the vector to store
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+                      Val, DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Next, the pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it will
+  // be ignored in lowering, but to mirror the handling of the other vector
+  // store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -3217,6 +3298,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
   case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index dfde0cca0f00c..a0d5b09253c32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1575,6 +1575,10 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def RegOrSink : Operand<Any> {
+  let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
 def AtomicCode : Operand<i32> {
   let PrintMethod = "printAtomicCode";
 }
@@ -1881,7 +1885,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
     "\t[$addr], {{$src1, $src2}};">;
   def _v4 : NVPTXInst<
     (outs),
-    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+    (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
          AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
     "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1889,8 +1893,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins O:$src1, O:$src2, O:$src3, O:$src4,
-           O:$src5, O:$src6, O:$src7, O:$src8,
+      (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+           RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
            AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
            ADDR:$addr),
       "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 4029e143ae2a4..cd90ae1601e1a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -592,6 +592,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
   return nullptr;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+                                      unsigned AddrSpace, bool IsMaskConstant) const {
+
+  if (!IsMaskConstant)
+    return false;
+  
+  // We currently only support this feature for 256-bit vectors, so the
+  // alignment must be at least 32 bytes.
+  if (Alignment < 32)
+    return false;
+
+  if (!ST->has256BitVectorLoadStore(AddrSpace))
+    return false;
+
+  auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+  if (!VTy)
+    return false;
+
+  auto *ScalarTy = VTy->getScalarType();
+  if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+      (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+    return true;
+
+  return false;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 78eb751cf3c2e..f06b94ba6ce80 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment,
+                          unsigned AddrSpace, bool IsMaskConstant) const override;
+
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 6886e8964e29e..2d4f8d7d6a6ec 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -290,7 +290,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace) const {
+                                    unsigned AddressSpace, bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
+                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 146e7d1047dd0..2cb334d8d6952 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1132,7 +1132,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getArgOperand(0)->getType(),
               CI->getParamAlign(1).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(3))))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<9>;
+; CHECK-NEXT:    .reg .b16 %rs<9>;
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT:    ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT:    ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT:    and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT:    setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT:    ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT:    and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT:    setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT:    ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT:    ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT:    not.pred %p5, %p1;
+; CHECK-NEXT:    @%p5 bra $L__BB0_2;
+; CHECK-NEXT:  // %bb.1: // %cond.store
+; CHECK-NEXT:    st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT:  $L__BB0_2: // %else
+; CHECK-NEXT:    and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT:    setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT:    not.pred %p6, %p2;
+; CHECK-NEXT:    @%p6 bra $L__BB0_4;
+; CHECK-NEXT:  // %bb.3: // %cond.store1
+; CHECK-NEXT:    st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT:  $L__BB0_4: // %else2
+; CHECK-NEXT:    setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT:    not.pred %p7, %p3;
+; CHECK-NEXT:    @%p7 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.5: // %cond.store3
+; CHECK-NEXT:    st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT:  $L__BB0_6: // %else4
+; CHECK-NEXT:    not.pred %p8, %p4;
+; CHECK-NEXT:    @%p8 bra $L__BB0_8;
+; CHECK-NEXT:  // %bb.7: // %cond.store5
+; CHECK-NEXT:    st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT:  $L__BB0_8: // %else6
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..0935bf80b04be
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,319 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll,
+; and contains testing for lowering 256-bit masked vector stores
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will get legalized before the backend via scalarization.
+;   This file tests that, even though we won't be generating them in the LSV.
+
+; 256-bit vector loads/stores are only legal for blackwell+, so on sm_90, the vectors will be split
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr %a
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr %a
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<9>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT:    ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT:    st.b32 [%rd2], %r5;
+; CHECK-NEXT:    st.b32 [%rd2+8], %r7;
+; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x float>, ptr %a
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT:    ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT:    ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT:    st.b64 [%rd6], %rd4;
+; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT:    ret;
+  %a.load = load <4 x double>, ptr %a
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x float>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT:    st.global.b64 [%rd6], %rd4;
+; SM90-NEXT:    st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT:    ret;
+  %a.load = load <4 x double>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Edge cases: every mask lane enabled, and every mask lane disabled.
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT:    st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT:    st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK:       {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; This is an example of the pattern the LoadStoreVectorizer (LSV) outputs for these masked stores.
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT:    st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT:    st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT:    ret;
+  %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+  %load05 = extractelement <8 x i32> %1, i32 0
+  %load16 = extractelement <8 x i32> %1, i32 1
+  %load38 = extractelement <8 x i32> %1, i32 3
+  %load49 = extractelement <8 x i32> %1, i32 4
+  %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+  %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+  %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+  %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+  %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+  %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+  %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+  %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)
>From 68ab3381409e56182a65a8a0f65af03b59704d91 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 17 Sep 2025 15:52:13 +0000
Subject: [PATCH 2/9] Clang format
---
 llvm/include/llvm/Analysis/TargetTransformInfoImpl.h   | 3 ++-
 llvm/lib/Analysis/TargetTransformInfo.cpp              | 6 ++++--
 llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h           | 4 ++--
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp | 3 ++-
 llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h   | 3 ++-
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp     | 5 +++--
 llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h       | 4 ++--
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h       | 3 ++-
 llvm/lib/Target/VE/VETargetTransformInfo.h             | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.cpp         | 3 ++-
 llvm/lib/Target/X86/X86TargetTransformInfo.h           | 3 ++-
 12 files changed, 27 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 06df6ca58b122..3d8b6d918d7fa 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,8 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace, bool IsMaskConstant) const {
+                                  unsigned AddressSpace,
+                                  bool IsMaskConstant) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 9310350b48585..7d02ea8ca5f96 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -468,8 +468,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace, bool IsMaskConstant) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
+                                             unsigned AddressSpace,
+                                             bool IsMaskConstant) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+                                     IsMaskConstant);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e40631d88748c..669d9e2ae1dad 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ee4f72552d90d..3c5b2195d8dcc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -189,8 +189,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
   bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
                          unsigned AddressSpace) const override;
 
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
+  bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index c989bf77a9d51..74df572ef1521 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,8 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
+                                        unsigned /*AddressSpace*/,
+                                        bool /*IsMaskConstant*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index e2674bb9cdad7..195cc616347b0 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,8 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant) const override;
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index cd90ae1601e1a..74c993ce6d7ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -593,11 +593,12 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
 }
 
 bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                      unsigned AddrSpace, bool IsMaskConstant) const {
+                                      unsigned AddrSpace,
+                                      bool IsMaskConstant) const {
 
   if (!IsMaskConstant)
     return false;
-  
+
   //  We currently only support this feature for 256-bit vectors, so the
   //  alignment must be at least 32
   if (Alignment < 32)
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index f06b94ba6ce80..80cf9772f76c5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,8 +181,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const override;
 
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddrSpace, bool IsMaskConstant) const override;
+  bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
+                          bool IsMaskConstant) const override;
 
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 2d4f8d7d6a6ec..8dfa748963a38 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -290,7 +290,8 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 4971d9148b747..3f33760a5ead9 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,8 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
+                          unsigned /*AddressSpace*/,
+                          bool /*IsMaskConstant*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index b16a2a593df03..851009f89e196 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
-                                    unsigned AddressSpace, bool IsMaskConstant) const {
+                                    unsigned AddressSpace,
+                                    bool IsMaskConstant) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 7f6ff65d427ed..c7c3d93600878 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,8 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
                          unsigned AddressSpace) const override;
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace, bool IsMaskConstant = false) const override;
+                          unsigned AddressSpace,
+                          bool IsMaskConstant = false) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
>From 8e7140ac245dd9d67e73686c8ac4109ac3b06299 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Fri, 17 Oct 2025 21:35:04 +0000
Subject: [PATCH 3/9] Update TTI to add enum value representing MaskKind to
 isLegalMaskedLoad/Store
---
 .../llvm/Analysis/TargetTransformInfo.h       | 22 +++++++++++--------
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  5 +++--
 llvm/lib/Analysis/TargetTransformInfo.cpp     | 10 +++++----
 .../AArch64/AArch64TargetTransformInfo.h      |  5 +++--
 .../lib/Target/ARM/ARMTargetTransformInfo.cpp |  3 ++-
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h  |  8 +++----
 .../Hexagon/HexagonTargetTransformInfo.cpp    |  5 +++--
 .../Hexagon/HexagonTargetTransformInfo.h      |  6 ++---
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 16 +++++++++++---
 .../Target/NVPTX/NVPTXTargetTransformInfo.h   |  5 ++++-
 .../Target/RISCV/RISCVTargetTransformInfo.h   |  5 +++--
 llvm/lib/Target/VE/VETargetTransformInfo.h    | 11 +++++-----
 .../lib/Target/X86/X86TargetTransformInfo.cpp |  5 +++--
 llvm/lib/Target/X86/X86TargetTransformInfo.h  | 13 ++++++-----
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |  9 ++++++--
 15 files changed, 81 insertions(+), 47 deletions(-)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f0d30913e1a52..c61c1a9a9c622 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -813,16 +813,20 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
-  /// Return true if the target supports masked store. A value of false for
-  /// IsMaskConstant indicates that the mask could either be variable or
-  /// constant. This is for targets that only support masked store with a
-  /// constant mask.
-  LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace,
-                                   bool IsMaskConstant = false) const;
+  /// Some targets only support masked load/store with a constant mask.
+  enum MaskKind {
+    VariableOrConstantMask, ///< The mask may be either variable or constant.
+    ConstantMask,           ///< The mask is known to be constant.
+  };
+
+  /// Return true if the target supports masked store.
+  LLVM_ABI bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     MaskKind MaskKind = VariableOrConstantMask) const;
   /// Return true if the target supports masked load.
-  LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    MaskKind MaskKind = VariableOrConstantMask) const;
 
   /// Return true if the target supports nontemporal store.
   LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 3d8b6d918d7fa..653ae750dec52 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -310,12 +310,13 @@ class TargetTransformInfoImplBase {
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
                                   unsigned AddressSpace,
-                                  bool IsMaskConstant) const {
+                                  TTI::MaskKind MaskKind) const {
     return false;
   }
 
   virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                 unsigned AddressSpace) const {
+                                 unsigned AddressSpace,
+                                 TTI::MaskKind MaskKind) const {
     return false;
   }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7d02ea8ca5f96..017e9b8d9fe00 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -469,14 +469,16 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
                                              unsigned AddressSpace,
-                                             bool IsMaskConstant) const {
+                                             TTI::MaskKind MaskKind) const {
   return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
-                                     IsMaskConstant);
+                                     MaskKind);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                            unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+                                            unsigned AddressSpace,
+                                            TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+                                    MaskKind);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 669d9e2ae1dad..4fda311738a90 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -316,13 +316,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 9b250e6cac3ab..7ee7e9a953f5c 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
 }
 
 bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned /*AddressSpace*/) const {
+                                   unsigned /*AddressSpace*/,
+                                   TTI::MaskKind /*MaskKind*/) const {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
 
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 3c5b2195d8dcc..401f5ea044dab 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,12 +186,12 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I) const override;
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                         unsigned AddressSpace) const override;
+  bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
-                          bool /*IsMaskConstant*/) const override {
-    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+                          TTI::MaskKind MaskKind) const override {
+    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
   }
 
   bool forceScalarizeMaskedGather(VectorType *VTy,
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 74df572ef1521..da006ff037eeb 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -342,14 +342,15 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
                                         unsigned /*AddressSpace*/,
-                                        bool /*IsMaskConstant*/) const {
+                                        TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
 }
 
 bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
-                                       unsigned /*AddressSpace*/) const {
+                                       unsigned /*AddressSpace*/,
+                                       TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 195cc616347b0..1ca7408d75126 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -167,9 +167,9 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned AddressSpace,
-                          bool IsMaskConstant) const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
+                          TTI::MaskKind MaskKind) const override;
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   /// @}
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 74c993ce6d7ea..ee08d941fb5ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -594,9 +594,8 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
 
 bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
                                       unsigned AddrSpace,
-                                      bool IsMaskConstant) const {
-
-  if (!IsMaskConstant)
+                                      TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
     return false;
 
   //  We currently only support this feature for 256-bit vectors, so the
@@ -619,6 +618,17 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
   return false;
 }
 
+bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
+                                     unsigned /*AddrSpace*/,
+                                     TTI::MaskKind MaskKind) const {
+  if (MaskKind != TTI::MaskKind::ConstantMask)
+    return false; // Only constant-mask queries are reported as legal.
+  // The access must be aligned to at least the store size of DataTy.
+  if (Alignment < DL.getTypeStoreSize(DataTy))
+    return false;
+  return true;
+}
+
 unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
   // 256 bit loads/stores are currently only supported for global address space
   if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 80cf9772f76c5..d7f4e1da4073b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -182,7 +182,10 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
                                   Intrinsic::ID IID) const override;
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
-                          bool IsMaskConstant) const override;
+                          TTI::MaskKind MaskKind) const override;
+
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
+                         TTI::MaskKind MaskKind) const override;
 
   unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8dfa748963a38..7988a6c35c768 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -286,12 +286,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
                           unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 3f33760a5ead9..eed3832c9f1fb 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -134,13 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
   }
 
   // Load & Store {
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+                    TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/,
-                          bool /*IsMaskConstant*/) const override {
+  bool isLegalMaskedStore(
+      Type *DataType, Align Alignment, unsigned /*AddressSpace*/,
+      TargetTransformInfo::MaskKind /*MaskKind*/) const override {
     return isVectorLaneType(*getLaneType(DataType));
   }
   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 851009f89e196..60f148f9f747e 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6317,7 +6317,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) {
 }
 
 bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned AddressSpace) const {
+                                   unsigned AddressSpace,
+                                   TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
@@ -6331,7 +6332,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
                                     unsigned AddressSpace,
-                                    bool IsMaskConstant) const {
+                                    TTI::MaskKind MaskKind) const {
   Type *ScalarTy = DataTy->getScalarType();
 
   // The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index c7c3d93600878..02b642f599270 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -268,11 +268,14 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2) const override;
   bool canMacroFuseCmp() const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
-  bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace,
-                          bool IsMaskConstant = false) const override;
+  bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+  bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override;
   bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
   bool isLegalNTStore(Type *DataType, Align Alignment) const override;
   bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 2cb334d8d6952..ff78dd172f38d 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedLoad(
               CI->getType(), CI->getParamAlign(0).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(0)->getType())
-                  ->getAddressSpace()))
+                  ->getAddressSpace(),
+              isConstantIntVector(CI->getArgOperand(2))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
@@ -1133,7 +1136,9 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getParamAlign(1).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
                   ->getAddressSpace(),
-              isConstantIntVector(CI->getArgOperand(3))))
+              isConstantIntVector(CI->getArgOperand(3))
+                  ? TTI::MaskKind::ConstantMask
+                  : TTI::MaskKind::VariableOrConstantMask))
         return false;
       scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
>From 6ec5661345e45c43421beafec7c0ab09ade9e5b4 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Fri, 17 Oct 2025 23:53:27 +0000
Subject: [PATCH 4/9] Add masked load lowering support
---
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  10 +-
 .../SelectionDAG/SelectionDAGBuilder.cpp      |   2 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  10 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |   1 +
 llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp  |   2 +-
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp   |  33 +-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 112 +++++-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |   2 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  24 +-
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      |  17 +-
 .../Target/NVPTX/NVPTXReplaceImageHandles.cpp |   4 +-
 .../Target/NVPTX/NVPTXTagInvariantLoads.cpp   |  25 +-
 .../floating-point-immediate-operands.mir     |   8 +-
 .../test/CodeGen/NVPTX/masked-load-vectors.ll | 367 ++++++++++++++++++
 llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir |   4 +-
 15 files changed, 584 insertions(+), 37 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 3b5f83f7c089a..6bb564ee56018 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2449,6 +2449,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
   SDValue PassThru = MLD->getPassThru();
   Align Alignment = MLD->getBaseAlign();
   ISD::LoadExtType ExtType = MLD->getExtensionType();
+  MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
 
   // Split Mask operand
   SDValue MaskLo, MaskHi;
@@ -2474,9 +2475,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
     std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-      MLD->getPointerInfo(), MachineMemOperand::MOLoad,
-      LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
-      MLD->getRanges());
+      MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
+      Alignment, MLD->getAAInfo(), MLD->getRanges());
 
   Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
                          MMO, MLD->getAddressingMode(), ExtType,
@@ -2499,8 +2499,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
           LoMemVT.getStoreSize().getFixedValue());
 
     MMO = DAG.getMachineFunction().getMachineMemOperand(
-        MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
-        Alignment, MLD->getAAInfo(), MLD->getRanges());
+        MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
+        MLD->getAAInfo(), MLD->getRanges());
 
     Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
                            HiMemVT, MMO, MLD->getAddressingMode(), ExtType,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 20a0efd3afa1c..4cc111e3941ed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5010,6 +5010,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
   auto MMOFlags = MachineMemOperand::MOLoad;
   if (I.hasMetadata(LLVMContext::MD_nontemporal))
     MMOFlags |= MachineMemOperand::MONonTemporal;
+  if (I.hasMetadata(LLVMContext::MD_invariant_load))
+    MMOFlags |= MachineMemOperand::MOInvariant;
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
       MachinePointerInfo(PtrOperand), MMOFlags,
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index c8717876157c5..8ca3cb46b5455 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -395,6 +395,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
   }
 }
 
+void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
+                                                raw_ostream &O) {
+  auto &Op = MI->getOperand(OpNum);
+  assert(Op.isImm() && "Invalid operand");
+  uint32_t Imm = (uint32_t)Op.getImm();
+  if (Imm != UINT32_MAX) {
+    O << ".pragma \"used_bytes_mask " << Imm << "\";\n\t";
+  }
+}
+
 void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
                                                  raw_ostream &O,
                                                  const char *Modifier) {
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index d373668aa591f..89137a954d2d8 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
                     StringRef Modifier = {});
   void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                        StringRef Modifier = {});
+  void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
   void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
                                  const char *Modifier = nullptr);
   void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index a3496090def3c..c8b53571c1e59 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
   const MachineOperand *ParamSymbol = Mov.uses().begin();
   assert(ParamSymbol->isSymbol());
 
-  constexpr unsigned LDInstBasePtrOpIdx = 5;
+  constexpr unsigned LDInstBasePtrOpIdx = 6;
   constexpr unsigned LDInstAddrSpaceOpIdx = 2;
   for (auto *LI : LoadInsts) {
     (LI->uses().begin() + LDInstBasePtrOpIdx)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 7e7ee754c250d..8e0399b493a24 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
   switch (N->getOpcode()) {
   case ISD::LOAD:
   case ISD::ATOMIC_LOAD:
+  case NVPTXISD::MLoadV1:
     if (tryLoad(N))
       return;
     break;
@@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
           ? NVPTX::PTXLdStInstCode::Signed
           : NVPTX::PTXLdStInstCode::Untyped;
 
+  uint32_t UsedBytesMask;
+  switch (N->getOpcode()) {
+  case ISD::LOAD:
+  case ISD::ATOMIC_LOAD:
+    UsedBytesMask = UINT32_MAX;
+    break;
+  case NVPTXISD::MLoadV1:
+    UsedBytesMask = N->getConstantOperandVal(N->getNumOperands() - 2);
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+
   assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
          FromTypeWidth <= 128 && "Invalid width for load");
 
@@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                    getI32Imm(CodeAddrSpace, DL),
                    getI32Imm(FromType, DL),
                    getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
                    Base,
                    Offset,
                    Chain};
@@ -1204,6 +1219,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
                                 : NVPTX::PTXLdStInstCode::Untyped;
 
   const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
+  const uint32_t UsedBytesMask =
+      N->getConstantOperandVal(N->getNumOperands() - 2);
 
   assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
 
@@ -1213,6 +1230,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
                    getI32Imm(CodeAddrSpace, DL),
                    getI32Imm(FromType, DL),
                    getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
                    Base,
                    Offset,
                    Chain};
@@ -1250,10 +1268,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   SDLoc DL(LD);
 
   unsigned ExtensionType;
+  uint32_t UsedBytesMask;
   if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
     ExtensionType = Load->getExtensionType();
+    UsedBytesMask = UINT32_MAX;
   } else {
     ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+    UsedBytesMask = LD->getConstantOperandVal(LD->getNumOperands() - 2);
   }
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
@@ -1265,8 +1286,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
            ExtensionType != ISD::NON_EXTLOAD));
 
   const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
-  SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
-                   Offset, LD->getChain()};
+  SDValue Ops[] = {getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   getI32Imm(UsedBytesMask, DL),
+                   Base,
+                   Offset,
+                   LD->getChain()};
 
   const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
   std::optional<unsigned> Opcode;
@@ -1277,6 +1302,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
                              NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
     break;
+  case NVPTXISD::MLoadV1:
+    Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
+                             NVPTX::LD_GLOBAL_NC_i64);
+    break;
   case NVPTXISD::LoadV2:
     Opcode =
         pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e52854c3b5627..cc23d1159932f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -769,7 +769,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+                         Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1130,6 +1131,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::LoadV2)
     MAKE_CASE(NVPTXISD::LoadV4)
     MAKE_CASE(NVPTXISD::LoadV8)
+    MAKE_CASE(NVPTXISD::MLoadV1)
     MAKE_CASE(NVPTXISD::LDUV2)
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)
@@ -3306,6 +3308,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::MLOAD:
+    return LowerMLOAD(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
@@ -3497,10 +3501,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+static std::tuple<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it creates
+  // two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used bytes
+  // of the entire vector load
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (unsigned I :
+       llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    // We technically only want to do this shift for every iteration *but* the
+    // first, but in the first iteration UsedBytesMask is 0, so this shift is
+    // a no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    if (Mask->getConstantOperandVal(I) != 0)
+      UsedBytesMask |= ElementMask;
+  }
+
+  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+         "Unexpected masked load with elements masked all on or all off");
+
+  // Create a new load SDNode to be handled normally by replaceLoadVector.
+  MemSDNode *NewLD = cast<MemSDNode>(
+      DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+  return {NewLD, UsedBytesMask};
+}
+
 /// replaceLoadVector - Convert vector loads into multi-output scalar loads.
 static std::optional<std::pair<SDValue, SDValue>>
 replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
+  MemSDNode *LD = cast<MemSDNode>(N);
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
 
@@ -3527,6 +3579,14 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
     return std::nullopt;
   }
 
+  // If we have a masked load, convert it to a normal load now
+  std::optional<uint32_t> UsedBytesMask = std::nullopt;
+  if (LD->getOpcode() == ISD::MLOAD) {
+    auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+    LD = std::get<0>(Result);
+    UsedBytesMask = std::get<1>(Result);
+  }
+
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
   // loaded type to i16 and propagate the "real" type as the memory type.
@@ -3555,9 +3615,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
   // Copy regular operands
   SmallVector<SDValue, 8> OtherOps(LD->ops());
 
+  OtherOps.push_back(
+      DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
   // The select routine does not have access to the LoadSDNode instance, so
   // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  OtherOps.push_back(
+      DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
 
   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
                                           LD->getMemOperand());
@@ -3645,6 +3709,43 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   llvm_unreachable("Unexpected custom lowering for load");
 }
 
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
+  // masked loads of these types and have to handle them here.
+  // v2f32 also needs to be handled here if the subtarget has f32x2
+  // instructions, making it legal.
+  //
+  // Note: misaligned masked loads should never reach this point
+  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+  // will validate alignment. Therefore, we do not need to special-case them
+  // here.
+  EVT VT = Op.getValueType();
+  if (NVPTX::isPackedVectorTy(VT) &&
+      (VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
+    auto Result =
+        convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+    MemSDNode *LD = std::get<0>(Result);
+    uint32_t UsedBytesMask = std::get<1>(Result);
+
+    SDLoc DL(LD);
+
+    // Copy regular operands
+    SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+    OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+    // The select routine does not have access to the LoadSDNode instance, so
+    // pass along the extension information
+    OtherOps.push_back(
+        DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    SDValue NewLD = DAG.getMemIntrinsicNode(
+        NVPTXISD::MLoadV1, DL, LD->getVTList(), OtherOps, LD->getMemoryVT(),
+        LD->getMemOperand());
+    return NewLD;
+  }
+  return SDValue();
+}
+
 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5555,9 +5656,13 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
     // here.
     Opcode = NVPTXISD::LoadV2;
+    // Append a "full" used bytes mask operand right before the extension type
+    // operand, signifying that all bytes are used.
+    Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
     Operands.push_back(DCI.DAG.getIntPtrConstant(
         cast<LoadSDNode>(LD)->getExtensionType(), DL));
     break;
+  // TODO: Do we need to support MLoadV1 here?
   case NVPTXISD::LoadV2:
     OldNumOutputs = 2;
     Opcode = NVPTXISD::LoadV4;
@@ -6793,6 +6898,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
+  case ISD::MLOAD:
     replaceLoadVector(N, DAG, Results, STI);
     return;
   case ISD::INTRINSIC_W_CHAIN:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63fa0bb9159ff..89bf0c290292a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -99,6 +99,7 @@ enum NodeType : unsigned {
   LoadV2,
   LoadV4,
   LoadV8,
+  MLoadV1,
   LDUV2, // LDU.v2
   LDUV4, // LDU.v4
   StoreV2,
@@ -349,6 +350,7 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a0d5b09253c32..24a40ab627a30 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1575,6 +1575,10 @@ def ADDR : Operand<pAny> {
   let MIOperandInfo = (ops ADDR_base, i32imm);
 }
 
+def UsedBytesMask : Operand<i32> {
+  let PrintMethod = "printUsedBytesMaskPragma";
+}
+
 def RegOrSink : Operand<Any> {
   let PrintMethod = "printRegisterOrSinkSymbol";
 }
@@ -1817,8 +1821,10 @@ def Callseq_End :
 class LD<NVPTXRegClass regclass>
   : NVPTXInst<
     (outs regclass:$dst),
-    (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
-         i32imm:$fromWidth, ADDR:$addr),
+    (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$fromWidth "
     "\t$dst, [$addr];">;
 
@@ -1850,21 +1856,27 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2),
     (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
-         AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2}}, [$addr];">;
   def _v4 : NVPTXInst<
     (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
     (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
-         AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
+         AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+         ADDR:$addr),
+    "${usedBytes}"
     "ld${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
     "\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];">;
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
             regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
-      (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign,
-           i32imm:$fromWidth, ADDR:$addr),
+      (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp,
+           AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+           ADDR:$addr),
+      "${usedBytes}"
       "ld${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
       "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, "
       "[$addr];">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22cf3a7eef2c1..65c951e7aa06d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2351,7 +2351,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
 // during the lifetime of the kernel.
 
 class LDG_G<NVPTXRegClass regclass>
-  : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+  : NVPTXInst<(outs regclass:$result),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth,
+                   UsedBytesMask:$usedBytes, ADDR:$src),
+               "${usedBytes}"
                "ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
 
 def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2363,19 +2366,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
 // Elementized vector ldg
 class VLDG_G_ELE_V2<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
-            (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
             "ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
 
 
 class VLDG_G_ELE_V4<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), 
-            (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
             "ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
 
 class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
   NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
                   regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
-             (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+            (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+             ADDR:$src),
+            "${usedBytes}"
              "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
 
 // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
index 320c0fb6950a7..4bbf49f93f43b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
@@ -1808,8 +1808,8 @@ bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
       // For CUDA, we preserve the param loads coming from function arguments
       return false;
 
-    assert(TexHandleDef.getOperand(6).isSymbol() && "Load is not a symbol!");
-    StringRef Sym = TexHandleDef.getOperand(6).getSymbolName();
+    assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
+    StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
     InstrsToRemove.insert(&TexHandleDef);
     Op.ChangeToES(Sym.data());
     MFI->getImageHandleSymbolIndex(Sym);
diff --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
index a4aff44ac04f6..6fa518e8d409b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp
@@ -27,13 +27,14 @@
 
 using namespace llvm;
 
-static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
+static bool isInvariantLoad(const Instruction *I, const Value *Ptr,
+                            const bool IsKernelFn) {
   // Don't bother with non-global loads
-  if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
+  if (Ptr->getType()->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL)
     return false;
 
   // If the load is already marked as invariant, we don't need to do anything
-  if (LI->getMetadata(LLVMContext::MD_invariant_load))
+  if (I->getMetadata(LLVMContext::MD_invariant_load))
     return false;
 
   // We use getUnderlyingObjects() here instead of getUnderlyingObject()
@@ -41,7 +42,7 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
   // not. We need to look through phi nodes to handle pointer induction
   // variables.
   SmallVector<const Value *, 8> Objs;
-  getUnderlyingObjects(LI->getPointerOperand(), Objs);
+  getUnderlyingObjects(Ptr, Objs);
 
   return all_of(Objs, [&](const Value *V) {
     if (const auto *A = dyn_cast<const Argument>(V))
@@ -53,9 +54,9 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) {
   });
 }
 
-static void markLoadsAsInvariant(LoadInst *LI) {
-  LI->setMetadata(LLVMContext::MD_invariant_load,
-                  MDNode::get(LI->getContext(), {}));
+static void markLoadsAsInvariant(Instruction *I) {
+  I->setMetadata(LLVMContext::MD_invariant_load,
+                 MDNode::get(I->getContext(), {}));
 }
 
 static bool tagInvariantLoads(Function &F) {
@@ -64,10 +65,18 @@ static bool tagInvariantLoads(Function &F) {
   bool Changed = false;
   for (auto &I : instructions(F)) {
     if (auto *LI = dyn_cast<LoadInst>(&I)) {
-      if (isInvariantLoad(LI, IsKernelFn)) {
+      if (isInvariantLoad(LI, LI->getPointerOperand(), IsKernelFn)) {
         markLoadsAsInvariant(LI);
         Changed = true;
       }
     }
+    // Masked loads are intrinsic calls, not LoadInsts; check them separately.
+    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load &&
+          isInvariantLoad(II, II->getOperand(0), IsKernelFn)) {
+        markLoadsAsInvariant(II);
+        Changed = true;
+      }
+    }
   }
   return Changed;
diff --git a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
index e3b072549bc04..3158916a3195c 100644
--- a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
+++ b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
@@ -40,9 +40,9 @@ registers:
   - { id: 7, class: b32 }
 body: |
   bb.0.entry:
-    %0 = LD_i32 0, 0, 4, 2, 32, &test_param_0, 0
+    %0 = LD_i32 0, 0, 4, 2, 32, -1, &test_param_0, 0
     %1 = CVT_f64_f32 %0, 0
-    %2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0
+    %2 = LD_i32 0, 0, 4, 0, 32, -1, &test_param_1, 0
   ; CHECK: %3:b64 = FADD_rnf64ri %1, double 3.250000e+00
     %3 = FADD_rnf64ri %1, double 3.250000e+00
     %4 = CVT_f32_f64 %3, 5
@@ -66,9 +66,9 @@ registers:
   - { id: 7, class: b32 }
 body: |
   bb.0.entry:
-    %0 = LD_i32 0, 0, 4, 2, 32, &test2_param_0, 0
+    %0 = LD_i32 0, 0, 4, 2, 32, -1, &test2_param_0, 0
     %1 = CVT_f64_f32 %0, 0
-    %2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0
+    %2 = LD_i32 0, 0, 4, 0, 32, -1, &test2_param_1, 0
   ; CHECK: %3:b64 = FADD_rnf64ri %1, double 0x7FF8000000000000
     %3 = FADD_rnf64ri %1, double 0x7FF8000000000000
     %4 = CVT_f32_f64 %3, 5
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
new file mode 100644
index 0000000000000..2e58aae6ad478
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -0,0 +1,367 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+
+; Different architectures are tested in this file for the following reasons:
+; - SM90 does not have 256-bit load/store instructions
+; - SM90 does not have masked store instructions
+; - SM90 does not support packed f32x2 instructions
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 61440";
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    .pragma "used_bytes_mask 3855";
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+; Masked stores are only supported for 32 byte element types,
+; Masked stores are only supported for 32-bit element types,
+; while masked loads are supported for all element types.
+define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_16xi16(
+; SM90:       {
+; SM90-NEXT:    .reg .b16 %rs<7>;
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 61440";
+; SM90-NEXT:    ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM90-NEXT:    .pragma "used_bytes_mask 3855";
+; SM90-NEXT:    ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r7;
+; SM90-NEXT:    mov.b32 {%rs5, %rs6}, %r5;
+; SM90-NEXT:    ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM90-NEXT:    st.global.b16 [%rd2], %rs5;
+; SM90-NEXT:    st.global.b16 [%rd2+2], %rs6;
+; SM90-NEXT:    st.global.b16 [%rd2+8], %rs3;
+; SM90-NEXT:    st.global.b16 [%rd2+10], %rs4;
+; SM90-NEXT:    st.global.b16 [%rd2+28], %rs1;
+; SM90-NEXT:    st.global.b16 [%rd2+30], %rs2;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_16xi16(
+; SM100:       {
+; SM100-NEXT:    .reg .b16 %rs<7>;
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_16xi16_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    mov.b32 {%rs1, %rs2}, %r8;
+; SM100-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM100-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
+; SM100-NEXT:    ld.param.b64 %rd2, [global_16xi16_param_1];
+; SM100-NEXT:    st.global.b16 [%rd2], %rs5;
+; SM100-NEXT:    st.global.b16 [%rd2+2], %rs6;
+; SM100-NEXT:    st.global.b16 [%rd2+8], %rs3;
+; SM100-NEXT:    st.global.b16 [%rd2+10], %rs4;
+; SM100-NEXT:    st.global.b16 [%rd2+28], %rs1;
+; SM100-NEXT:    st.global.b16 [%rd2+30], %rs2;
+; SM100-NEXT:    ret;
+  %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) %a, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+  tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) %b, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_8xi32_no_align_param_0];
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_8xi32_no_align_param_1];
+; CHECK-NEXT:    ld.global.b32 %r2, [%rd1+8];
+; CHECK-NEXT:    ld.global.b32 %r3, [%rd1+28];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r1;
+; CHECK-NEXT:    st.global.b32 [%rd2+8], %r2;
+; CHECK-NEXT:    st.global.b32 [%rd2+28], %r3;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+
+define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_invariant(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<9>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 61440";
+; SM90-NEXT:    ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT:    .pragma "used_bytes_mask 3855";
+; SM90-NEXT:    ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r5;
+; SM90-NEXT:    st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT:    st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_8xi32_invariant(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<9>;
+; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_invariant_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 4026535695";
+; SM100-NEXT:    ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
+; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT:    ret;
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 3";
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_param_1];
+; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_invariant(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_invariant_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 3";
+; CHECK-NEXT:    ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_invariant_param_1];
+; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xi16_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xi16_no_align_param_0];
+; CHECK-NEXT:    ld.global.b16 %rs1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
+; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 2, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 5";
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_param_1];
+; CHECK-NEXT:    st.global.b8 [%rd2], %r1;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_invariant(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
+; CHECK-NEXT:    .pragma "used_bytes_mask 5";
+; CHECK-NEXT:    ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
+; CHECK-NEXT:    st.global.b8 [%rd2], %r1;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_4xi8_no_align_param_0];
+; CHECK-NEXT:    ld.global.b8 %rs1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_4xi8_no_align_param_1];
+; CHECK-NEXT:    ld.global.b8 %rs2, [%rd1+2];
+; CHECK-NEXT:    st.global.b8 [%rd2], %rs1;
+; CHECK-NEXT:    st.global.b8 [%rd2+2], %rs2;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 2, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; In sm100+, we pack 2xf32 loads into a single b64 load while lowering
+define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<3>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 15";
+; SM90-NEXT:    ld.global.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_2xf32_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_2xf32(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<2>;
+; SM100-NEXT:    .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 15";
+; SM100-NEXT:    ld.global.b64 %rd2, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd3, [global_2xf32_param_1];
+; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT:    st.global.b32 [%rd3], %r1;
+; SM100-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32_invariant(
+; SM90:       {
+; SM90-NEXT:    .reg .b32 %r<3>;
+; SM90-NEXT:    .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM90-NEXT:    .pragma "used_bytes_mask 15";
+; SM90-NEXT:    ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT:    ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
+; SM90-NEXT:    st.global.b32 [%rd2], %r1;
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: global_2xf32_invariant(
+; SM100:       {
+; SM100-NEXT:    .reg .b32 %r<2>;
+; SM100-NEXT:    .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM100-NEXT:    .pragma "used_bytes_mask 15";
+; SM100-NEXT:    ld.global.nc.b64 %rd2, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
+; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT:    st.global.b32 [%rd3], %r1;
+; SM100-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xf32_no_align(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [global_2xf32_no_align_param_0];
+; CHECK-NEXT:    ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), i32, <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), i32, <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), i32, <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), i32, <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), i32, <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), i32, <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), i32, <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), i32, <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), i32, <2 x i1>)
+!0 = !{}
diff --git a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
index dfc84177fb0e6..a84b7fcd33836 100644
--- a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
+++ b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
@@ -77,7 +77,7 @@ constants:       []
 machineFunctionInfo: {}
 body:             |
   bb.0:
-    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, &retval0, 0 :: (load (s128), addrspace 101)
+    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s128), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %4:b32 = ProxyRegB32 killed %0
     %5:b32 = ProxyRegB32 killed %1
@@ -86,7 +86,7 @@ body:             |
     ; CHECK: STV_i32_v4 killed %0, killed %1, killed %2, killed %3
     STV_i32_v4 killed %4, killed %5, killed %6, killed %7, 0, 0, 101, 32, &func_retval0, 0 :: (store (s128), addrspace 101)
 
-    %8:b32 = LD_i32 0, 0, 101, 3, 32, &retval0, 0 :: (load (s32), addrspace 101)
+    %8:b32 = LD_i32 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s32), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %9:b32 = ProxyRegB32 killed %8
     %10:b32 = ProxyRegB32 killed %9
>From c44f8a2d8c528230fd303d847d15c3335c87c9a4 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Sat, 18 Oct 2025 00:27:50 +0000
Subject: [PATCH 5/9] Fix typo
---
 llvm/test/CodeGen/NVPTX/masked-load-vectors.ll | 1 -
 1 file changed, 1 deletion(-)
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
index 2e58aae6ad478..5b4136ec1307e 100644
--- a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -45,7 +45,6 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
   ret void
 }
 
-; Masked stores are only supported for 32 byte element types,
 ; Masked stores are only supported for 32-bit element types,
 ; while masked loads are supported for all element types.
 define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
>From abbf0ba102a8ecd1334c4fbd88fd12e1463d5b55 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 16:39:12 +0000
Subject: [PATCH 6/9] Review feedback
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 30 ++++++++-----------
 .../Target/NVPTX/NVPTXTargetTransformInfo.cpp |  9 ++----
 2 files changed, 15 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index cc23d1159932f..6a5ea30ea0a0a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3230,10 +3230,9 @@ static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
   assert(Mask.getValueType().getVectorNumElements() ==
              ValVT.getVectorNumElements() &&
          "Mask size must be the same as the vector size");
-  for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
-    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
-           "Mask elements must be constants");
-    if (Mask->getConstantOperandVal(I) == 0) {
+  for (auto [I, Op] : enumerate(Mask->ops())) {
+    // Mask elements must be constants.
+    if (Op.getNode()->getAsZExtVal() == 0) {
       // Append a sentinel register 0 to the Ops vector to represent a masked
       // off element, this will be handled in tablegen
       Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
@@ -3501,7 +3500,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
-static std::tuple<MemSDNode *, uint32_t>
+static std::pair<MemSDNode *, uint32_t>
 convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
   SDValue Chain = N->getOperand(0);
   SDValue BasePtr = N->getOperand(1);
@@ -3526,16 +3525,14 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
   uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
   uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
 
-  for (unsigned I :
-       llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
-    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
-           "Mask elements must be constants");
-    // We technically only want to do this shift for every iteration *but* the
-    // first, but in the first iteration NewMask is 0, so this shift is a
-    // no-op.
+  for (SDValue Op : llvm::reverse(Mask->ops())) {
+    // We technically only want to do this shift for every
+    // iteration *but* the first, but in the first iteration NewMask is 0, so
+    // this shift is a no-op.
     UsedBytesMask <<= ElementSizeInBytes;
 
-    if (Mask->getConstantOperandVal(I) != 0)
+    // Mask elements must be constants.
+    if (Op->getAsZExtVal() != 0)
       UsedBytesMask |= ElementMask;
   }
 
@@ -3581,11 +3578,8 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
 
   // If we have a masked load, convert it to a normal load now
   std::optional<uint32_t> UsedBytesMask = std::nullopt;
-  if (LD->getOpcode() == ISD::MLOAD) {
-    auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
-    LD = std::get<0>(Result);
-    UsedBytesMask = std::get<1>(Result);
-  }
+  if (LD->getOpcode() == ISD::MLOAD)
+    std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
 
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index ee08d941fb5ea..00bb129cbf801 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -610,12 +610,9 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
   if (!VTy)
     return false;
 
-  auto *ScalarTy = VTy->getScalarType();
-  if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
-      (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
-    return true;
-
-  return false;
+  auto *ElemTy = VTy->getScalarType();
+  return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+         (ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
 }
 
 bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
>From 8bf4ed6b245a6cff60a741ba1c62a3c884b20762 Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 19:03:55 +0000
Subject: [PATCH 7/9] More review feedback
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 6a5ea30ea0a0a..8c86912dc95a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3525,7 +3525,7 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
   uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
   uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
 
-  for (SDValue Op : llvm::reverse(Mask->ops())) {
+  for (SDValue Op : reverse(Mask->ops())) {
     // We technically only want to do this shift for every
     // iteration *but* the first, but in the first iteration NewMask is 0, so
     // this shift is a no-op.
@@ -3714,8 +3714,7 @@ SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
   // will validate alignment. Therefore, we do not need to special case handle
   // them here.
   EVT VT = Op.getValueType();
-  if (NVPTX::isPackedVectorTy(VT) &&
-      (VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
+  if (NVPTX::isPackedVectorTy(VT)) {
     auto Result =
         convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
     MemSDNode *LD = std::get<0>(Result);
@@ -5656,7 +5655,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     Operands.push_back(DCI.DAG.getIntPtrConstant(
         cast<LoadSDNode>(LD)->getExtensionType(), DL));
     break;
-  // TODO do we need to support MLoadV1 here?
   case NVPTXISD::LoadV2:
     OldNumOutputs = 2;
     Opcode = NVPTXISD::LoadV4;
>From 7315fbd119f469614f364e07411c1ff67183437f Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Tue, 21 Oct 2025 19:24:59 +0000
Subject: [PATCH 8/9] Fix ARM TTI
---
 llvm/lib/Target/ARM/ARMTargetTransformInfo.h | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 401f5ea044dab..9a940cb3453b7 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -186,11 +186,15 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I) const override;
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
-                         TTI::MaskKind MaskKind) const override;
-
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
-                          TTI::MaskKind MaskKind) const override {
+  bool
+  isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+
+  bool
+  isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override {
     return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
   }
 
>From 95c335cf3dcf0ebb501521f9f00a9058e1e7fecf Mon Sep 17 00:00:00 2001
From: Drew Kersnar <dkersnar at nvidia.com>
Date: Wed, 22 Oct 2025 17:04:21 +0000
Subject: [PATCH 9/9] Make fixes based on recent TOT changes, adjust tests,
 expand LoadV8 unpacking mov handling for v2i32 packed types
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 12 ++--
 .../Scalar/ScalarizeMaskedMemIntrin.cpp       |  4 +-
 llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll  | 16 ++---
 .../NVPTX/machinelicm-no-preheader.mir        | 12 ++--
 .../test/CodeGen/NVPTX/masked-load-vectors.ll | 72 +++++++++----------
 .../NVPTX/masked-store-variable-mask.ll       |  4 +-
 .../CodeGen/NVPTX/masked-store-vectors-256.ll | 47 ++++++------
 7 files changed, 81 insertions(+), 86 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8c86912dc95a4..73b2e930b7b8e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5660,9 +5660,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     Opcode = NVPTXISD::LoadV4;
     break;
   case NVPTXISD::LoadV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the load
-    // size here. This is already a 256-bit load.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // load size here. This is already a 256-bit load.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
       return SDValue();
     OldNumOutputs = 4;
     Opcode = NVPTXISD::LoadV8;
@@ -5737,9 +5737,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
     Opcode = NVPTXISD::StoreV4;
     break;
   case NVPTXISD::StoreV4:
-    // V8 is only supported for f32. Don't forget, we're not changing the store
-    // size here. This is already a 256-bit store.
-    if (ElementVT != MVT::v2f32)
+    // V8 is only supported for f32/i32. Don't forget, we're not changing the
+    // store size here. This is already a 256-bit store.
+    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
       return SDValue();
     Opcode = NVPTXISD::StoreV8;
     break;
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index ff78dd172f38d..b7b08ae61ec52 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1124,7 +1124,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(), CI->getParamAlign(0).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(0)->getType())
                   ->getAddressSpace(),
-              isConstantIntVector(CI->getArgOperand(2))
+              isConstantIntVector(CI->getArgOperand(1))
                   ? TTI::MaskKind::ConstantMask
                   : TTI::MaskKind::VariableOrConstantMask))
         return false;
@@ -1136,7 +1136,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getParamAlign(1).valueOrOne(),
               cast<PointerType>(CI->getArgOperand(1)->getType())
                   ->getAddressSpace(),
-              isConstantIntVector(CI->getArgOperand(3))
+              isConstantIntVector(CI->getArgOperand(2))
                   ? TTI::MaskKind::ConstantMask
                   : TTI::MaskKind::VariableOrConstantMask))
         return false;
diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
index 3fac29f74125b..d219493d2b31b 100644
--- a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
+++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll
@@ -346,19 +346,15 @@ define i32 @ld_global_v8i32(ptr addrspace(1) %ptr) {
 ; SM100-LABEL: ld_global_v8i32(
 ; SM100:       {
 ; SM100-NEXT:    .reg .b32 %r<16>;
-; SM100-NEXT:    .reg .b64 %rd<6>;
+; SM100-NEXT:    .reg .b64 %rd<2>;
 ; SM100-EMPTY:
 ; SM100-NEXT:  // %bb.0:
 ; SM100-NEXT:    ld.param.b64 %rd1, [ld_global_v8i32_param_0];
-; SM100-NEXT:    ld.global.nc.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
-; SM100-NEXT:    mov.b64 {%r1, %r2}, %rd5;
-; SM100-NEXT:    mov.b64 {%r3, %r4}, %rd4;
-; SM100-NEXT:    mov.b64 {%r5, %r6}, %rd3;
-; SM100-NEXT:    mov.b64 {%r7, %r8}, %rd2;
-; SM100-NEXT:    add.s32 %r9, %r7, %r8;
-; SM100-NEXT:    add.s32 %r10, %r5, %r6;
-; SM100-NEXT:    add.s32 %r11, %r3, %r4;
-; SM100-NEXT:    add.s32 %r12, %r1, %r2;
+; SM100-NEXT:    ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT:    add.s32 %r9, %r1, %r2;
+; SM100-NEXT:    add.s32 %r10, %r3, %r4;
+; SM100-NEXT:    add.s32 %r11, %r5, %r6;
+; SM100-NEXT:    add.s32 %r12, %r7, %r8;
 ; SM100-NEXT:    add.s32 %r13, %r9, %r10;
 ; SM100-NEXT:    add.s32 %r14, %r11, %r12;
 ; SM100-NEXT:    add.s32 %r15, %r13, %r14;
diff --git a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
index 0b2d85600a2ef..4be91dfc60c6a 100644
--- a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
+++ b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir
@@ -26,10 +26,10 @@ body:             |
   ; CHECK: bb.0.entry:
   ; CHECK-NEXT:   successors: %bb.2(0x30000000), %bb.3(0x50000000)
   ; CHECK-NEXT: {{  $}}
-  ; CHECK-NEXT:   [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
-  ; CHECK-NEXT:   [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+  ; CHECK-NEXT:   [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+  ; CHECK-NEXT:   [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
   ; CHECK-NEXT:   [[ADD64ri:%[0-9]+]]:b64 = nuw ADD64ri killed [[LD_i64_]], 2
-  ; CHECK-NEXT:   [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, [[ADD64ri]], 0
+  ; CHECK-NEXT:   [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, -1, [[ADD64ri]], 0
   ; CHECK-NEXT:   [[SETP_i32ri:%[0-9]+]]:b1 = SETP_i32ri [[LD_i32_]], 0, 0
   ; CHECK-NEXT:   CBranch killed [[SETP_i32ri]], %bb.2
   ; CHECK-NEXT: {{  $}}
@@ -54,10 +54,10 @@ body:             |
   bb.0.entry:
     successors: %bb.2(0x30000000), %bb.1(0x50000000)
 
-    %5:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
-    %6:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
+    %5:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101)
+    %6:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101)
     %0:b64 = nuw ADD64ri killed %6, 2
-    %1:b32 = LD_i32 0, 0, 1, 3, 32, %0, 0
+    %1:b32 = LD_i32 0, 0, 1, 3, 32, -1, %0, 0
     %7:b1 = SETP_i32ri %5, 0, 0
     CBranch killed %7, %bb.2
     GOTO %bb.1
diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
index 5b4136ec1307e..7c7c51be9567d 100644
--- a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll
@@ -40,8 +40,8 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_param_1];
 ; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
 ; SM100-NEXT:    ret;
-  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -93,8 +93,8 @@ define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    st.global.b16 [%rd2+28], %rs1;
 ; SM100-NEXT:    st.global.b16 [%rd2+30], %rs2;
 ; SM100-NEXT:    ret;
-  %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) %a, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
-  tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) %b, i32 32, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
+  %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) align 32 %a, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>, <16 x i16> poison)
+  tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) align 32 %b, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true>)
   ret void
 }
 
@@ -114,8 +114,8 @@ define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    st.global.b32 [%rd2+8], %r2;
 ; CHECK-NEXT:    st.global.b32 [%rd2+28], %r3;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 16, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 16 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 16 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -150,8 +150,8 @@ define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_invariant_param_1];
 ; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
 ; SM100-NEXT:    ret;
-  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) %a, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x i32> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -170,8 +170,8 @@ define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
 ; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
-  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
@@ -190,8 +190,8 @@ define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    mov.b32 {%rs1, _}, %r1;
 ; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
-  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
@@ -207,8 +207,8 @@ define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xi16_no_align_param_1];
 ; CHECK-NEXT:    st.global.b16 [%rd2], %rs1;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) %a, i32 2, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
-  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) %b, i32 4, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 2 %a, <2 x i1> <i1 true, i1 false>, <2 x i16> poison)
+  tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
@@ -227,8 +227,8 @@ define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
 ; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
-  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -247,8 +247,8 @@ define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7772U;
 ; CHECK-NEXT:    st.global.b8 [%rd2+2], %r2;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
-  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -266,8 +266,8 @@ define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    st.global.b8 [%rd2], %rs1;
 ; CHECK-NEXT:    st.global.b8 [%rd2+2], %rs2;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) %a, i32 2, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
-  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) %b, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 2 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -299,8 +299,8 @@ define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
 ; SM100-NEXT:    st.global.b32 [%rd3], %r1;
 ; SM100-NEXT:    ret;
-  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
-  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
@@ -331,8 +331,8 @@ define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    mov.b64 {%r1, _}, %rd2;
 ; SM100-NEXT:    st.global.b32 [%rd3], %r1;
 ; SM100-NEXT:    ret;
-  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 8, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
-  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
@@ -348,19 +348,19 @@ define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; CHECK-NEXT:    ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
 ; CHECK-NEXT:    st.global.b32 [%rd2], %r1;
 ; CHECK-NEXT:    ret;
-  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) %a, i32 4, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
-  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) %b, i32 8, <2 x i1> <i1 true, i1 false>)
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
   ret void
 }
 
-declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), i32, <8 x i1>, <8 x i32>)
-declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
-declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), i32, <16 x i1>, <16 x i16>)
-declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), i32, <16 x i1>)
-declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), i32, <2 x i1>, <2 x i16>)
-declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), i32, <2 x i1>)
-declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), i32, <4 x i1>, <4 x i8>)
-declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), i32, <4 x i1>)
-declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), i32, <2 x i1>, <2 x float>)
-declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), i32, <2 x i1>)
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), <2 x i1>)
 !0 = !{}
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
index 7d8f65b25bb02..9f23acaf93bc8 100644
--- a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -49,8 +49,8 @@ define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x
 ; CHECK-NEXT:  $L__BB0_8: // %else6
 ; CHECK-NEXT:    ret;
   %a.load = load <4 x i64>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> %mask)
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> %mask)
   ret void
 }
 
-declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
index 0935bf80b04be..feb7b7e0a0b39 100644
--- a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -34,7 +34,7 @@ define void @generic_8xi32(ptr %a, ptr %b) {
 ; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
 ; CHECK-NEXT:    ret;
   %a.load = load <8 x i32>, ptr %a
-  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -52,7 +52,7 @@ define void @generic_4xi64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
 ; CHECK-NEXT:    ret;
   %a.load = load <4 x i64>, ptr %a
-  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -72,7 +72,7 @@ define void @generic_8xfloat(ptr %a, ptr %b) {
 ; CHECK-NEXT:    st.b32 [%rd2+28], %r4;
 ; CHECK-NEXT:    ret;
   %a.load = load <8 x float>, ptr %a
-  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -90,7 +90,7 @@ define void @generic_4xdouble(ptr %a, ptr %b) {
 ; CHECK-NEXT:    st.b64 [%rd6+16], %rd2;
 ; CHECK-NEXT:    ret;
   %a.load = load <4 x double>, ptr %a
-  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -124,7 +124,7 @@ define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
 ; SM100-NEXT:    ret;
   %a.load = load <8 x i32>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -153,7 +153,7 @@ define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
 ; SM100-NEXT:    ret;
   %a.load = load <4 x i64>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -185,7 +185,7 @@ define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
 ; SM100-NEXT:    ret;
   %a.load = load <8 x float>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
   ret void
 }
 
@@ -214,7 +214,7 @@ define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
 ; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
 ; SM100-NEXT:    ret;
   %a.load = load <4 x double>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) %b, i32 32, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -236,17 +236,16 @@ define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b)
 ;
 ; SM100-LABEL: global_8xi32_all_mask_on(
 ; SM100:       {
-; SM100-NEXT:    .reg .b32 %r<9>;
-; SM100-NEXT:    .reg .b64 %rd<3>;
+; SM100-NEXT:    .reg .b64 %rd<7>;
 ; SM100-EMPTY:
 ; SM100-NEXT:  // %bb.0:
 ; SM100-NEXT:    ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
-; SM100-NEXT:    ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
-; SM100-NEXT:    ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
-; SM100-NEXT:    st.global.v8.b32 [%rd2], {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8};
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
 ; SM100-NEXT:    ret;
   %a.load = load <8 x i32>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
   ret void
 }
 
@@ -258,7 +257,7 @@ define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b)
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ret;
   %a.load = load <8 x i32>, ptr addrspace(1) %a
-  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) %b, i32 32, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
   ret void
 }
 
@@ -304,16 +303,16 @@ define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
   %7 = insertelement <8 x i32> %6, i32 poison, i32 5
   %8 = insertelement <8 x i32> %7, i32 poison, i32 6
   %9 = insertelement <8 x i32> %8, i32 poison, i32 7
-  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) %out, i32 32, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) align 32 %out, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
   ret void
 }
 
-declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, i32, <8 x i1>)
-declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, i32, <4 x i1>)
-declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, i32, <8 x i1>)
-declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, i32, <4 x i1>)
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, <4 x i1>)
 
-declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), i32, <8 x i1>)
-declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), i32, <4 x i1>)
-declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), i32, <8 x i1>)
-declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), i32, <4 x i1>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), <4 x i1>)
    
    
More information about the llvm-commits
mailing list